From 9a7c767bb0ee28bcf36b4a139c7584635ba85270 Mon Sep 17 00:00:00 2001 From: FL33TW00D Date: Mon, 22 Jul 2024 12:13:58 +0100 Subject: [PATCH 1/6] chore: start model tracing --- Cargo.toml | 2 +- crates/ratchet-core/Cargo.toml | 2 +- crates/ratchet-core/src/compiled_op.rs | 2 + crates/ratchet-core/src/executable.rs | 6 +- crates/ratchet-core/src/gpu/device.rs | 1 + crates/ratchet-core/src/op.rs | 3 + crates/ratchet-models/tests/whisper.rs | 124 ++----------------------- package.json | 2 +- pnpm-lock.yaml | 8 +- 9 files changed, 25 insertions(+), 125 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 5a0de92e..e8dc228d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -54,7 +54,7 @@ encase = { git = "https://github.com/cwfitzgerald/encase", branch = "add-member" env_logger = "0.11.3" fern = "0.6.2" getrandom = "0.2" -glam = "0.27.0" +glam = "0.28.0" globwalk = "0.8.1" gloo-net = { version = "0.5.0", default-features = false } hound = "3.5.1" diff --git a/crates/ratchet-core/Cargo.toml b/crates/ratchet-core/Cargo.toml index 3a7e601a..84838ae9 100644 --- a/crates/ratchet-core/Cargo.toml +++ b/crates/ratchet-core/Cargo.toml @@ -4,7 +4,7 @@ version = "0.1.0" edition = "2021" [features] -default = ["rand", "testing"] +default = ["rand", "testing", "debug"] gpu-profiling = ["dep:tabled", "dep:itertools"] rand = ["dep:rand", "dep:rand_distr"] plotting = ["dep:dot3", "dep:tempfile"] diff --git a/crates/ratchet-core/src/compiled_op.rs b/crates/ratchet-core/src/compiled_op.rs index 788ac37c..16c32fbf 100644 --- a/crates/ratchet-core/src/compiled_op.rs +++ b/crates/ratchet-core/src/compiled_op.rs @@ -4,6 +4,8 @@ use crate::gpu::{ }; use crate::{drvec, rvec, KernelKey, OperationError, RVec, Tensor}; use derive_new::new; +#[cfg(feature = "debug")] +use std::sync::Arc; use wgpu::DynamicOffset; //Compiled op represents a single kernel invocation diff --git a/crates/ratchet-core/src/executable.rs b/crates/ratchet-core/src/executable.rs index 11d9625a..488bd9d5 100644 --- a/crates/ratchet-core/src/executable.rs +++ b/crates/ratchet-core/src/executable.rs @@ -1,9 +1,13 @@ use crate::gpu::{GpuUniform, PoolError, StaticResourcePoolAccessor, WgpuDevice}; use crate::CompiledOp; use derive_new::new; +use wgpu::SubmissionIndex; + #[cfg(not(feature = "debug"))] use std::marker::PhantomData; -use wgpu::SubmissionIndex; + +#[cfg(feature = "debug")] +use crate::Tensor; /// # Executable /// diff --git a/crates/ratchet-core/src/gpu/device.rs b/crates/ratchet-core/src/gpu/device.rs index b9d2ce18..3dc053ce 100644 --- a/crates/ratchet-core/src/gpu/device.rs +++ b/crates/ratchet-core/src/gpu/device.rs @@ -93,6 +93,7 @@ impl WgpuDevice { log::warn!("Forcing F32 precision"); features.SHADER_F16 = false; } + features.SHADER_F16 = false; if std::env::var("RATCHET_DISABLE_SUBGROUPS").is_ok() { log::warn!("Disabling subgroup support"); diff --git a/crates/ratchet-core/src/op.rs b/crates/ratchet-core/src/op.rs index 475070de..0fc54a01 100644 --- a/crates/ratchet-core/src/op.rs +++ b/crates/ratchet-core/src/op.rs @@ -9,6 +9,9 @@ use crate::{ use std::borrow::Cow; use std::fmt::Debug; +#[cfg(feature = "debug")] +use {crate::gpu::BufferUsagesExt, std::sync::Arc}; + #[derive(Clone, Debug)] #[non_exhaustive] pub enum LazyOp { diff --git a/crates/ratchet-models/tests/whisper.rs b/crates/ratchet-models/tests/whisper.rs index 0bffc906..1392f498 100644 --- a/crates/ratchet-models/tests/whisper.rs +++ b/crates/ratchet-models/tests/whisper.rs @@ -18,118 +18,6 @@ fn log_init() { console_log::init_with_level(log::Level::Debug).unwrap(); } -/* -#[wasm_bindgen_test] -async fn distil_large_v3_encoder() -> Result<(), JsValue> { - log_init(); - let model_repo = - ApiBuilder::from_hf("FL33TW00D-HF/distil-whisper-large-v3", RepoType::Model).build(); - let model_data = model_repo.get("distil-large-v3_q8_0.gguf").await?; - let config_data = model_repo.get("config.json").await?; - - let ground_repo = ApiBuilder::from_hf("FL33TW00D-HF/ratchet-util", RepoType::Dataset).build(); - let input_npy = ground_repo.get("distil_large_v3_q80_mm0_input.npy").await?; - let ground_npy = ground_repo.get("distil_large_v3_q80_mm0_hs.npy").await?; - - let mut reader = std::io::BufReader::new(std::io::Cursor::new(model_data.to_vec())); - let header = gguf::Header::read(&mut reader).unwrap(); - let config: Config = serde_json::from_slice(&config_data.to_vec()).unwrap(); - - let device = Device::request_device(DeviceRequest::GPU).await.unwrap(); - - let input_data = &input_npy.to_vec(); - let input = Tensor::from_npy_bytes::(input_data, &device).unwrap(); - let ground = Tensor::from_npy_bytes::(&ground_npy.to_vec(), &Device::CPU).unwrap(); - - let encoder = WhisperEncoder::load(&header, &config, &mut reader, &device).unwrap(); - let result = encoder - .schedule(input) - .unwrap() - .full() - .unwrap() - .resolve() - .unwrap(); - let ours = result.to(&Device::CPU).await.unwrap(); - ground.all_close(&ours, 1e-3, 1e-3).unwrap(); - Ok(()) -}*/ - -/* -#[wasm_bindgen_test] -async fn distil_large_v3_decoder() -> Result<(), JsValue> { - log_init(); - let model_repo = - ApiBuilder::from_hf("FL33TW00D-HF/distil-whisper-large-v3", RepoType::Model).build(); - let model_data = model_repo.get("distil-large-v3_q8_0.gguf").await?; - let config_data = model_repo.get("config.json").await?; - - let ground_repo = ApiBuilder::from_hf("FL33TW00D-HF/ratchet-util", RepoType::Dataset).build(); - let hs_data = ground_repo.get("distil_large_v3_q80_mm0_hs.npy").await?; - - let mut reader = std::io::BufReader::new(std::io::Cursor::new(model_data.to_vec())); - let header = gguf::Header::read(&mut reader).unwrap(); - let config: Config = serde_json::from_slice(&config_data.to_vec()).unwrap(); - - let device = Device::request_device(DeviceRequest::GPU).await.unwrap(); - let audio_ctx = Tensor::from_npy_bytes::(&hs_data.to_vec(), &device) - .unwrap() - .half() - .unwrap() - .resolve() - .unwrap(); - log::debug!("Audio Context: {:?}", audio_ctx); - let mut decoder = WhisperDecoder::load(&header, &config, &mut reader, &device).unwrap(); - - let mut tokens = vec![50258, 50259, 50360]; - let mut all_tokens = tokens.clone(); - let mut all_logits = vec![]; - while tokens[tokens.len() - 1] != 50257 { - let token_t = Tensor::from_data(tokens.clone(), shape![1, tokens.len()], device.clone()); - let result = decoder - .schedule([audio_ctx.clone(), token_t]) - .unwrap() - .resolve_debug() - .unwrap(); - - let our_logits = result.to(&Device::CPU).await.unwrap(); - all_logits.push(our_logits.clone()); - let nd_logits = our_logits.to_ndarray_view::(); - log::info!("Logits: {:?}", nd_logits); - - let sliced = nd_logits.slice(s![.., -1.., ..51866]).remove_axis(Axis(1)); - decoder.cache_mut().update(tokens.len()); - - tokens = sliced - .map_axis(Axis(1), |row| row.argmax_skipnan().unwrap()) - .iter() - .map(|&x| x as i32) - .collect::>(); - println!("Token: {:?}", tokens); - panic!(); - all_tokens.extend(tokens.clone()); - } - - let ground_tokens = vec![ - 50258, 50259, 50360, 50365, 639, 307, 264, 4532, 3479, 587, 11, 15578, 264, 881, 2062, 847, - 11, 34674, 5932, 30340, 295, 3123, 4397, 608, 1652, 13, 50517, 50530, 6947, 472, 575, - 12023, 4365, 11, 20899, 11, 10445, 11, 18356, 11, 4225, 4782, 11, 50624, 50626, 1804, 4651, - 3123, 4397, 34922, 8963, 862, 6352, 13, 50695, 50701, 821, 311, 257, 3804, 5214, 11, 2610, - 5214, 11, 6383, 11, 2643, 5214, 11, 293, 544, 13, 50797, 50807, 10246, 8963, 2436, 2965, - 281, 747, 604, 1081, 13, 50875, 50881, 400, 456, 366, 867, 34674, 862, 365, 11, 293, 1184, - 472, 1487, 365, 1080, 1065, 2121, 11377, 11, 51002, 51007, 4532, 3479, 5864, 293, 1019, 11, - 5456, 4122, 300, 30686, 25038, 1286, 13, 51120, 51135, 30062, 264, 6582, 29814, 412, 264, - 10155, 11, 1849, 1426, 11, 587, 11, 264, 3874, 34544, 412, 264, 7267, 3096, 13, 51243, - 51246, 18463, 428, 1032, 412, 264, 1032, 5675, 13, 51287, 51290, 30062, 264, 16629, 7283, - 13, 51320, 51328, 400, 613, 862, 6352, 3318, 1214, 281, 1254, 257, 3123, 4397, 34922, 8963, - 1081, 6352, 11, 370, 27985, 11, 51504, 51504, 370, 6239, 11, 370, 44078, 1688, 356, 9942, - 11, 291, 603, 528, 281, 8963, 552, 439, 13, 51623, 51635, 25642, 12089, 1652, 366, 1027, - 14759, 490, 7336, 836, 65, 13, 51743, 51743, 440, 4356, 436, 366, 11, 264, 1101, 436, 366, - 13, 51834, - ]; - assert_eq!(all_tokens, ground_tokens); - Ok(()) -}*/ - /* #[wasm_bindgen_test] async fn tiny_encoder() -> Result<(), JsValue> { @@ -164,7 +52,7 @@ async fn tiny_encoder() -> Result<(), JsValue> { async fn tiny_decoder() -> Result<(), JsValue> { log_init(); let model_repo = ApiBuilder::from_hf("FL33TW00D-HF/whisper-tiny", RepoType::Model).build(); - let model_data = model_repo.get("tiny_f32.gguf").await?; + let model_data = model_repo.get("tiny_q8_0.gguf").await?; let config_data = model_repo.get("config.json").await?; let ground_repo = ApiBuilder::from_hf("FL33TW00D-HF/ratchet-util", RepoType::Dataset).build(); @@ -176,20 +64,22 @@ async fn tiny_decoder() -> Result<(), JsValue> { let device = Device::request_device(DeviceRequest::GPU).await.unwrap(); - let audio_ctx = Tensor::from_npy_bytes::(&hs_data.to_vec(), &device)? - .cast(device.compute_precision())?; + let audio_ctx = Tensor::from_npy_bytes::(&hs_data.to_vec(), &device) + .unwrap() + .cast(device.compute_precision()) + .unwrap(); let mut decoder = WhisperDecoder::load(&header, &config, &mut reader, &device).unwrap(); let mut tokens = vec![50258, 50259, 50359]; let mut all_tokens = tokens.clone(); let mut all_logits = vec![]; - let vocab_size = 51866; + let vocab_size = 51865; while tokens[tokens.len() - 1] != 50257 { let token_t = Tensor::from_data(tokens.clone(), shape![1, tokens.len()], device.clone()); let result = decoder .schedule([audio_ctx.clone(), token_t]) .unwrap() - .resolve() + .resolve_debug() .unwrap(); let our_logits = result.to(&Device::CPU).await.unwrap(); diff --git a/package.json b/package.json index 6cbe15d0..f64c93c4 100644 --- a/package.json +++ b/package.json @@ -5,6 +5,6 @@ "private": true, "devDependencies": { "pkg-pr-new": "0.0.15", - "wasm-pack": "0.12.1" + "wasm-pack": "0.13.0" } } diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index ba57a17d..d7f35fdd 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -12,8 +12,8 @@ importers: specifier: 0.0.15 version: 0.0.15 wasm-pack: - specifier: 0.12.1 - version: 0.12.1 + specifier: 0.13.0 + version: 0.13.0 examples/ratchet-phi: dependencies: @@ -1608,8 +1608,8 @@ packages: engines: {node: ^14.17.0 || ^16.13.0 || >=18.0.0} dev: true - /wasm-pack@0.12.1: - resolution: {integrity: sha512-dIyKWUumPFsGohdndZjDXRFaokUT/kQS+SavbbiXVAvA/eN4riX5QNdB6AhXQx37zNxluxQkuixZUgJ8adKjOg==} + /wasm-pack@0.13.0: + resolution: {integrity: sha512-AmboGZEnZoIcVCzSlkLEmNFEqJN+IwgshJ5S7pi30uNUTce4LvWkifQzsQRxnWj47G8gkqZxlyGlyQplsnIS7w==} hasBin: true requiresBuild: true dependencies: From a7e733eac09ddff6ffba2beccfa98eac1ac6795f Mon Sep 17 00:00:00 2001 From: FL33TW00D Date: Mon, 22 Jul 2024 13:02:43 +0100 Subject: [PATCH 2/6] chore: start model tracing --- crates/ratchet-core/Cargo.toml | 8 +- crates/ratchet-core/src/compiled_op.rs | 6 +- crates/ratchet-core/src/executable.rs | 58 +++++------- crates/ratchet-core/src/op.rs | 25 ++--- crates/ratchet-core/src/tensor.rs | 98 +++++++++----------- crates/ratchet-models/src/whisper/decoder.rs | 2 +- 6 files changed, 88 insertions(+), 109 deletions(-) diff --git a/crates/ratchet-core/Cargo.toml b/crates/ratchet-core/Cargo.toml index 84838ae9..6c56c7d2 100644 --- a/crates/ratchet-core/Cargo.toml +++ b/crates/ratchet-core/Cargo.toml @@ -4,13 +4,13 @@ version = "0.1.0" edition = "2021" [features] -default = ["rand", "testing", "debug"] +default = ["rand", "testing", "trace"] gpu-profiling = ["dep:tabled", "dep:itertools"] rand = ["dep:rand", "dep:rand_distr"] plotting = ["dep:dot3", "dep:tempfile"] testing = ["dep:npyz", "dep:ndarray"] pyo3 = ["dep:pyo3", "dep:numpy", "dep:regex"] -debug = [] #dump every node +trace = ["dep:uuid"] [build-dependencies] tera = { workspace = true } @@ -61,6 +61,10 @@ pyo3 = { workspace = true, features = ["auto-initialize"], optional = true } regex = { workspace = true, optional = true } numpy = { workspace = true, optional = true, features=["half"]} +# Trace +uuid = { version="1.1.0", features = ["v4", "fast-rng"], optional = true } + + [target.'cfg(target_arch = "wasm32")'.dependencies] wasm-bindgen.workspace = true futures-intrusive.workspace = true diff --git a/crates/ratchet-core/src/compiled_op.rs b/crates/ratchet-core/src/compiled_op.rs index 16c32fbf..878a77c5 100644 --- a/crates/ratchet-core/src/compiled_op.rs +++ b/crates/ratchet-core/src/compiled_op.rs @@ -4,7 +4,7 @@ use crate::gpu::{ }; use crate::{drvec, rvec, KernelKey, OperationError, RVec, Tensor}; use derive_new::new; -#[cfg(feature = "debug")] +#[cfg(feature = "trace")] use std::sync::Arc; use wgpu::DynamicOffset; @@ -18,8 +18,8 @@ pub struct CompiledOp { storage_groups: RVec, offset: DynamicOffset, //offset into the metadata uniform buffer pub kernel_key: KernelKey, - #[cfg(feature = "debug")] - pub debug_buffer: Option>, + #[cfg(feature = "trace")] + pub trace_buffer: Option>, } impl CompiledOp { diff --git a/crates/ratchet-core/src/executable.rs b/crates/ratchet-core/src/executable.rs index 488bd9d5..98d154be 100644 --- a/crates/ratchet-core/src/executable.rs +++ b/crates/ratchet-core/src/executable.rs @@ -3,10 +3,10 @@ use crate::CompiledOp; use derive_new::new; use wgpu::SubmissionIndex; -#[cfg(not(feature = "debug"))] +#[cfg(not(feature = "trace"))] use std::marker::PhantomData; -#[cfg(feature = "debug")] +#[cfg(feature = "trace")] use crate::Tensor; /// # Executable @@ -17,9 +17,9 @@ use crate::Tensor; pub struct Executable<'t> { steps: Vec, gpu_uniform: GpuUniform, - #[cfg(feature = "debug")] - debug_list: Vec<&'t Tensor>, - #[cfg(not(feature = "debug"))] + #[cfg(feature = "trace")] + trace_list: Vec<&'t Tensor>, + #[cfg(not(feature = "trace"))] _phantom: PhantomData<&'t ()>, } @@ -32,6 +32,21 @@ pub enum ExecutionError { DebuggingError(&'static str), } +impl Executable<'_> { + pub fn steps(&self) -> &[CompiledOp] { + &self.steps + } + + pub fn gpu_uniform(&self) -> &GpuUniform { + &self.gpu_uniform + } + + #[cfg(feature = "trace")] + pub fn debug_list(&self) -> &[&Tensor] { + &self.trace_list + } +} + impl Executable<'_> { #[cfg(not(feature = "gpu-profiling"))] pub fn dispatch(&self, device: &WgpuDevice) -> Result { @@ -62,15 +77,13 @@ impl Executable<'_> { Ok(device.queue().submit(Some(encoder.finish()))) } - #[cfg(feature = "debug")] - pub(crate) fn dispatch_debugging( + #[cfg(feature = "trace")] + pub(crate) fn dispatch_trace( &self, device: &WgpuDevice, ) -> Result { - use crate::{wgpu_buffer_to_cpu_buffer, DeviceStorage}; - let pipeline_resources = device.pipeline_resources(); - assert!(self.debug_list.len() == self.steps.len()); + assert!(self.trace_list.len() == self.steps.len()); let mut last_index = None; for (step_index, step) in self.steps.iter().enumerate() { @@ -96,7 +109,7 @@ impl Executable<'_> { cpass.dispatch_workgroups(x_count, y_count, z_count); } - let result_t = self.debug_list[step_index].clone(); + let result_t = self.trace_list[step_index].clone(); let gpu_storage = result_t.storage(); let result_buf = &gpu_storage .as_ref() @@ -106,7 +119,7 @@ impl Executable<'_> { .inner; let debug_buffer = step - .debug_buffer + .trace_buffer .as_ref() .ok_or(ExecutionError::DebuggingError( "Failed to get debug buffer.", @@ -117,27 +130,6 @@ impl Executable<'_> { last_index = Some(index); } - //Dump all of our debug results - for (si, step) in self.steps.iter().enumerate() { - let d = device.clone(); - let dt = self.debug_list[si].dt(); - let debug_buffer = step.debug_buffer.clone().unwrap(); - let alignment = dt.size_of(); - let kernel_key = step.kernel_key.clone(); - #[cfg(target_arch = "wasm32")] - { - wasm_bindgen_futures::spawn_local(async move { - let cpu_buf = wgpu_buffer_to_cpu_buffer(&debug_buffer, alignment, d).await; - log::debug!("{}: {}\n {:?}\n", si, kernel_key, cpu_buf.dump(dt, false)); - }); - } - #[cfg(not(target_arch = "wasm32"))] - { - let cpu_buf = wgpu_buffer_to_cpu_buffer(&debug_buffer, alignment, &d); - log::debug!("{}: {}\n {:?}\n", si, kernel_key, cpu_buf.dump(dt, false)); - } - } - Ok(last_index.unwrap()) } diff --git a/crates/ratchet-core/src/op.rs b/crates/ratchet-core/src/op.rs index 0fc54a01..d7bdeb58 100644 --- a/crates/ratchet-core/src/op.rs +++ b/crates/ratchet-core/src/op.rs @@ -9,7 +9,7 @@ use crate::{ use std::borrow::Cow; use std::fmt::Debug; -#[cfg(feature = "debug")] +#[cfg(feature = "trace")] use {crate::gpu::BufferUsagesExt, std::sync::Arc}; #[derive(Clone, Debug)] @@ -286,7 +286,6 @@ pub trait GPUOperation: Operation { uniform: &mut CpuUniform, device: &WgpuDevice, can_inplace: bool, - debug: bool, ) -> Result { let kernel = self.select_kernel(); @@ -341,17 +340,13 @@ pub trait GPUOperation: Operation { can_inplace, )?; - #[cfg(feature = "debug")] - let debug_buffer = if debug { - Some(Arc::new(device.create_buffer(&wgpu::BufferDescriptor { - label: Some("debug buffer"), - size: dst.num_bytes() as _, - usage: wgpu::BufferUsages::standard(), - mapped_at_creation: false, - }))) - } else { - None - }; + #[cfg(feature = "trace")] + let trace_buffer = Some(Arc::new(device.create_buffer(&wgpu::BufferDescriptor { + label: Some("debug buffer"), + size: dst.num_bytes() as _, + usage: wgpu::BufferUsages::standard(), + mapped_at_creation: false, + }))); Ok(CompiledOp::new( pipeline_handle, @@ -359,8 +354,8 @@ pub trait GPUOperation: Operation { storage_bind_groups, offset as _, kernel_src_desc.key, - #[cfg(feature = "debug")] - debug_buffer, + #[cfg(feature = "trace")] + trace_buffer, )) } } diff --git a/crates/ratchet-core/src/tensor.rs b/crates/ratchet-core/src/tensor.rs index 1dc4b430..9ccd5a64 100644 --- a/crates/ratchet-core/src/tensor.rs +++ b/crates/ratchet-core/src/tensor.rs @@ -719,30 +719,28 @@ impl Tensor { uniform: &mut CpuUniform, device: &WgpuDevice, can_ip: bool, - debug: bool, ) -> Option { match self.op() { - LazyOp::Binary(b) => b.compile_gpu(self, uniform, device, can_ip, debug).ok(), - LazyOp::Cast(c) => c.compile_gpu(self, uniform, device, can_ip, debug).ok(), - LazyOp::Matmul(m) => m.compile_gpu(self, uniform, device, can_ip, debug).ok(), - LazyOp::Softmax(s) => s.compile_gpu(self, uniform, device, can_ip, debug).ok(), - LazyOp::RoPE(r) => r.compile_gpu(self, uniform, device, can_ip, debug).ok(), - LazyOp::Unary(u) => u.compile_gpu(self, uniform, device, can_ip, debug).ok(), - LazyOp::Reindex(r) => r.compile_gpu(self, uniform, device, can_ip, debug).ok(), - LazyOp::Concat(c) => c.compile_gpu(self, uniform, device, can_ip, debug).ok(), - LazyOp::Norm(n) => n.compile_gpu(self, uniform, device, can_ip, debug).ok(), - LazyOp::Conv(c) => c.compile_gpu(self, uniform, device, can_ip, debug).ok(), - LazyOp::Select(i) => i.compile_gpu(self, uniform, device, can_ip, debug).ok(), - LazyOp::IndexWrite(i) => i.compile_gpu(self, uniform, device, can_ip, debug).ok(), - LazyOp::Cache(c) => c.compile_gpu(self, uniform, device, can_ip, debug).ok(), + LazyOp::Binary(b) => b.compile_gpu(self, uniform, device, can_ip).ok(), + LazyOp::Cast(c) => c.compile_gpu(self, uniform, device, can_ip).ok(), + LazyOp::Matmul(m) => m.compile_gpu(self, uniform, device, can_ip).ok(), + LazyOp::Softmax(s) => s.compile_gpu(self, uniform, device, can_ip).ok(), + LazyOp::RoPE(r) => r.compile_gpu(self, uniform, device, can_ip).ok(), + LazyOp::Unary(u) => u.compile_gpu(self, uniform, device, can_ip).ok(), + LazyOp::Reindex(r) => r.compile_gpu(self, uniform, device, can_ip).ok(), + LazyOp::Concat(c) => c.compile_gpu(self, uniform, device, can_ip).ok(), + LazyOp::Norm(n) => n.compile_gpu(self, uniform, device, can_ip).ok(), + LazyOp::Conv(c) => c.compile_gpu(self, uniform, device, can_ip).ok(), + LazyOp::Select(i) => i.compile_gpu(self, uniform, device, can_ip).ok(), + LazyOp::IndexWrite(i) => i.compile_gpu(self, uniform, device, can_ip).ok(), + LazyOp::Cache(c) => c.compile_gpu(self, uniform, device, can_ip).ok(), LazyOp::Const => None, LazyOp::View(_) => None, } } - fn resolve_inner(self, debug: bool) -> Result { + fn create_executable(&self, device: &WgpuDevice) -> Result { let mut uniform = CpuUniform::new(); - let device = self.device().try_gpu()?; device.begin_pass(); let execution_order = self.execution_order(); @@ -750,12 +748,6 @@ impl Tensor { let mut compiled_ops = Vec::with_capacity(execution_order.len()); let mut allocations = device.allocate_cfg(&execution_order, device)?; - #[cfg(feature = "plotting")] - crate::plot::render_to_file(execution_order.last().unwrap(), "prealloc.svg").unwrap(); - - #[cfg(feature = "debug")] - let mut compute_dsts = Vec::new(); - for t in execution_order.iter() { log::debug!("Compiling: {:?}", t.op().name()); assert!(t.device().is_gpu()); @@ -773,51 +765,47 @@ impl Tensor { let to_modify = t.op().srcs()[0]; let can_inplace = t.op().supports_inplace() && to_modify.strong_count() == 1; - if let Some(compiled_op) = t.compile_gpu(&mut uniform, device, can_inplace, debug) { + if let Some(compiled_op) = t.compile_gpu(&mut uniform, device, can_inplace) { compiled_ops.push(compiled_op); - #[cfg(feature = "debug")] - compute_dsts.push(*t); } else { log::warn!("Compilation failed for operation: {:?}", t.op().name()); } } - #[cfg(feature = "plotting")] - crate::plot::render_to_file(execution_order.last().unwrap(), "alloc.svg").unwrap(); - - let executable = Executable::new( - compiled_ops, - uniform.into_gpu(device)?, - #[cfg(feature = "debug")] - compute_dsts, - ); - #[cfg(feature = "debug")] - let index = if debug { - if cfg!(feature = "debug") { - executable.dispatch_debugging(device).unwrap() - } else { - panic!("Debugging is only available in debug builds. Call `resolve()` instead of `resolve_debug()`.") - } - } else { - executable.dispatch(device).unwrap() - }; - #[cfg(not(feature = "debug"))] - let index = executable.dispatch(device).unwrap(); - device.poll(wgpu::MaintainBase::WaitForSubmissionIndex(index)); - Ok(self) + Ok(Executable::new(compiled_ops, uniform.into_gpu(device)?)) } + /// # Resolve + /// + /// All work in Ratchet is lazy, and no computation is done until `resolve` is called. + /// + /// Upon calling resolve, all work required to compute this `Tensor` is done. + /// The tensor is then returned with the underlying memory populated. pub fn resolve(self) -> Result { - self.resolve_inner(false) + let device = self.device.try_gpu()?; + let executable = self.create_executable(&device)?; + let index = executable.dispatch(device).unwrap(); + device.poll(wgpu::MaintainBase::WaitForSubmissionIndex(index)); + Ok(self) } - /// Resolves the tensor computations and copies the output - /// from each operation to a debug buffer. + /// # Tracing + /// Note: trace feature flag must be enabled + /// + /// WebGPU runs across heterogenous device, and as such, there + /// may be numerical differences. /// - /// The copy calls are inserted between each operation, so inplace - /// operations are captured. - pub fn resolve_debug(self) -> Result { - self.resolve_inner(true) + /// This method returns a `Vec` containing all the intermediate + /// tensors computed during the resolution of this tensor. The final + /// tensor in this list is the tensor upon which `resolve` was called. + #[cfg(feature = "trace")] + pub fn resolve_trace(self) -> Result, TensorError> { + let device = self.device.try_gpu()?; + let executable = self.create_executable(&device)?; + let index = executable.dispatch_trace(device).unwrap(); + device.poll(wgpu::MaintainBase::WaitForSubmissionIndex(index)); + + //Now the trace_list is populated } fn to_gpu(&self, dst_device: &Device) -> Result { diff --git a/crates/ratchet-models/src/whisper/decoder.rs b/crates/ratchet-models/src/whisper/decoder.rs index d912b46d..1375688a 100644 --- a/crates/ratchet-models/src/whisper/decoder.rs +++ b/crates/ratchet-models/src/whisper/decoder.rs @@ -340,7 +340,7 @@ def ground(options): Tensor::from_data(tokens.clone(), shape![1, tokens.len()], device.clone()); let result = decoder .schedule([audio_ctx.clone(), token_t])? - .resolve_debug()?; + .resolve_trace()?; let our_logits = result.to(&Device::CPU)?; let nd_logits = our_logits.to_ndarray_view::(); From 2dce94d747ca7035406c18f9889ee749e093169b Mon Sep 17 00:00:00 2001 From: FL33TW00D Date: Mon, 22 Jul 2024 13:35:33 +0100 Subject: [PATCH 3/6] chore: adding tracing --- crates/ratchet-core/Cargo.toml | 5 +- crates/ratchet-core/src/executable.rs | 21 ++------- crates/ratchet-core/src/tensor.rs | 49 ++++++++++++++++++-- crates/ratchet-core/src/tensor_id.rs | 6 +++ crates/ratchet-models/src/whisper/decoder.rs | 4 +- crates/ratchet-models/tests/whisper.rs | 2 +- 6 files changed, 60 insertions(+), 27 deletions(-) diff --git a/crates/ratchet-core/Cargo.toml b/crates/ratchet-core/Cargo.toml index 6c56c7d2..4edc5555 100644 --- a/crates/ratchet-core/Cargo.toml +++ b/crates/ratchet-core/Cargo.toml @@ -10,7 +10,7 @@ rand = ["dep:rand", "dep:rand_distr"] plotting = ["dep:dot3", "dep:tempfile"] testing = ["dep:npyz", "dep:ndarray"] pyo3 = ["dep:pyo3", "dep:numpy", "dep:regex"] -trace = ["dep:uuid"] +trace = ["dep:uuid", "dep:web-time"] [build-dependencies] tera = { workspace = true } @@ -46,7 +46,7 @@ num = { workspace = true } rand_distr = { workspace = true, optional = true } rand = { workspace = true, optional = true } glam = { workspace = true } -npyz = { workspace = true, optional = true } +npyz = { workspace = true, optional = true, features=["half"] } ndarray = { workspace = true, optional = true } #Plotting @@ -63,6 +63,7 @@ numpy = { workspace = true, optional = true, features=["half"]} # Trace uuid = { version="1.1.0", features = ["v4", "fast-rng"], optional = true } +web-time = { workspace = true, optional = true } [target.'cfg(target_arch = "wasm32")'.dependencies] diff --git a/crates/ratchet-core/src/executable.rs b/crates/ratchet-core/src/executable.rs index 98d154be..85edbab8 100644 --- a/crates/ratchet-core/src/executable.rs +++ b/crates/ratchet-core/src/executable.rs @@ -15,10 +15,10 @@ use crate::Tensor; /// containing metadata for all operations. #[derive(new)] pub struct Executable<'t> { - steps: Vec, - gpu_uniform: GpuUniform, + pub(crate) steps: Vec, + pub(crate) gpu_uniform: GpuUniform, #[cfg(feature = "trace")] - trace_list: Vec<&'t Tensor>, + pub(crate) trace_list: Vec<&'t Tensor>, #[cfg(not(feature = "trace"))] _phantom: PhantomData<&'t ()>, } @@ -32,21 +32,6 @@ pub enum ExecutionError { DebuggingError(&'static str), } -impl Executable<'_> { - pub fn steps(&self) -> &[CompiledOp] { - &self.steps - } - - pub fn gpu_uniform(&self) -> &GpuUniform { - &self.gpu_uniform - } - - #[cfg(feature = "trace")] - pub fn debug_list(&self) -> &[&Tensor] { - &self.trace_list - } -} - impl Executable<'_> { #[cfg(not(feature = "gpu-profiling"))] pub fn dispatch(&self, device: &WgpuDevice) -> Result { diff --git a/crates/ratchet-core/src/tensor.rs b/crates/ratchet-core/src/tensor.rs index 9ccd5a64..d7a7c3d3 100644 --- a/crates/ratchet-core/src/tensor.rs +++ b/crates/ratchet-core/src/tensor.rs @@ -34,6 +34,8 @@ pub enum TensorError { #[error("Failed to transfer data to host")] TransferError, #[error(transparent)] + IoError(#[from] std::io::Error), + #[error(transparent)] OperationError(#[from] OperationError), } @@ -748,6 +750,8 @@ impl Tensor { let mut compiled_ops = Vec::with_capacity(execution_order.len()); let mut allocations = device.allocate_cfg(&execution_order, device)?; + #[cfg(feature = "trace")] + let mut trace_list = Vec::with_capacity(execution_order.len()); for t in execution_order.iter() { log::debug!("Compiling: {:?}", t.op().name()); assert!(t.device().is_gpu()); @@ -767,12 +771,19 @@ impl Tensor { if let Some(compiled_op) = t.compile_gpu(&mut uniform, device, can_inplace) { compiled_ops.push(compiled_op); + #[cfg(feature = "trace")] + trace_list.push(*t); } else { log::warn!("Compilation failed for operation: {:?}", t.op().name()); } } - Ok(Executable::new(compiled_ops, uniform.into_gpu(device)?)) + Ok(Executable::new( + compiled_ops, + uniform.into_gpu(device)?, + #[cfg(feature = "trace")] + trace_list, + )) } /// # Resolve @@ -799,13 +810,45 @@ impl Tensor { /// tensors computed during the resolution of this tensor. The final /// tensor in this list is the tensor upon which `resolve` was called. #[cfg(feature = "trace")] - pub fn resolve_trace(self) -> Result, TensorError> { + pub fn trace(self) -> Result, TensorError> { let device = self.device.try_gpu()?; let executable = self.create_executable(&device)?; let index = executable.dispatch_trace(device).unwrap(); device.poll(wgpu::MaintainBase::WaitForSubmissionIndex(index)); - //Now the trace_list is populated + let Executable { trace_list, .. } = executable; + let mut result = trace_list.iter().map(|t| (*t).clone()).collect::>(); + result.push(self); + Ok(result) + } + + /// # Serialization + /// + /// We may want to serialize a trace to disk to determine platform discrepancies. + /// + /// This method does the following: + /// 1. Creates a trace directory with a UUID, time, and device details + /// 2. Serializes each tensor in the trace to disk, with the name being the tensor ID + #[cfg(feature = "trace")] + pub fn serialize_trace(trace: Vec) -> Result<(), TensorError> { + use half::f16; + use web_time::Instant; + let trace_dir = format!( + "trace-{}-{:?}", + uuid::Uuid::new_v4().to_string(), + Instant::now() + ); + std::fs::create_dir(&trace_dir).map_err(|e| TensorError::IoError(e))?; + for t in trace.iter() { + let id = t.id(); + let path = format!("ratchet-{}/{}.npy", trace_dir, id); + match t.dt() { + DType::F16 => t.write_npy::(&path), + DType::F32 => t.write_npy::(&path), + _ => unimplemented!(), + }; + } + Ok(()) } fn to_gpu(&self, dst_device: &Device) -> Result { diff --git a/crates/ratchet-core/src/tensor_id.rs b/crates/ratchet-core/src/tensor_id.rs index cf820269..cbbde1f9 100644 --- a/crates/ratchet-core/src/tensor_id.rs +++ b/crates/ratchet-core/src/tensor_id.rs @@ -8,6 +8,12 @@ impl std::fmt::Debug for TensorId { } } +impl std::fmt::Display for TensorId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "T{}", self.0) + } +} + impl TensorId { pub(crate) fn new() -> Self { // https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805 diff --git a/crates/ratchet-models/src/whisper/decoder.rs b/crates/ratchet-models/src/whisper/decoder.rs index 1375688a..b63f4457 100644 --- a/crates/ratchet-models/src/whisper/decoder.rs +++ b/crates/ratchet-models/src/whisper/decoder.rs @@ -338,9 +338,7 @@ def ground(options): while tokens[tokens.len() - 1] != 50257 { let token_t = Tensor::from_data(tokens.clone(), shape![1, tokens.len()], device.clone()); - let result = decoder - .schedule([audio_ctx.clone(), token_t])? - .resolve_trace()?; + let result = decoder.schedule([audio_ctx.clone(), token_t])?.trace()?; let our_logits = result.to(&Device::CPU)?; let nd_logits = our_logits.to_ndarray_view::(); diff --git a/crates/ratchet-models/tests/whisper.rs b/crates/ratchet-models/tests/whisper.rs index 1392f498..539a520e 100644 --- a/crates/ratchet-models/tests/whisper.rs +++ b/crates/ratchet-models/tests/whisper.rs @@ -79,7 +79,7 @@ async fn tiny_decoder() -> Result<(), JsValue> { let result = decoder .schedule([audio_ctx.clone(), token_t]) .unwrap() - .resolve_debug() + .resolve() .unwrap(); let our_logits = result.to(&Device::CPU).await.unwrap(); From bafe6570f3656904a8fe5855af433d5e87d59c9d Mon Sep 17 00:00:00 2001 From: FL33TW00D Date: Tue, 23 Jul 2024 13:26:13 +0100 Subject: [PATCH 4/6] chore: tracing --- .gitignore | 3 +- crates/ratchet-core/src/lib.rs | 2 + crates/ratchet-core/src/tensor.rs | 36 ++---------- crates/ratchet-core/src/trace.rs | 54 +++++++++++++++++ crates/ratchet-models/src/whisper/decoder.rs | 62 +++++++++++--------- crates/ratchet-models/tests/whisper.rs | 2 +- 6 files changed, 98 insertions(+), 61 deletions(-) create mode 100644 crates/ratchet-core/src/trace.rs diff --git a/.gitignore b/.gitignore index 6d94af04..580aa8ab 100644 --- a/.gitignore +++ b/.gitignore @@ -26,4 +26,5 @@ models/** # Python local env venv/ -.venv/ \ No newline at end of file +.venv/ +**/trace-* diff --git a/crates/ratchet-core/src/lib.rs b/crates/ratchet-core/src/lib.rs index 8bfdd6e5..b058f0a5 100644 --- a/crates/ratchet-core/src/lib.rs +++ b/crates/ratchet-core/src/lib.rs @@ -15,6 +15,7 @@ mod storage; mod strides; mod tensor; mod tensor_id; +mod trace; pub use compiled_op::*; pub use device::*; @@ -31,6 +32,7 @@ pub use storage::*; pub use strides::*; pub use tensor::*; pub use tensor_id::*; +pub use trace::*; #[cfg(feature = "plotting")] pub use plot::render_to_file; diff --git a/crates/ratchet-core/src/tensor.rs b/crates/ratchet-core/src/tensor.rs index d7a7c3d3..7bbeb1f1 100644 --- a/crates/ratchet-core/src/tensor.rs +++ b/crates/ratchet-core/src/tensor.rs @@ -22,6 +22,9 @@ use ndarray::{ArrayD, ArrayViewD, Dimension}; #[cfg(all(not(target_arch = "wasm32"), feature = "pyo3"))] use numpy::PyArrayDyn; +#[cfg(feature = "trace")] +use crate::Trace; + // thiserror error for Tensor #[derive(thiserror::Error, Debug)] pub enum TensorError { @@ -810,7 +813,7 @@ impl Tensor { /// tensors computed during the resolution of this tensor. The final /// tensor in this list is the tensor upon which `resolve` was called. #[cfg(feature = "trace")] - pub fn trace(self) -> Result, TensorError> { + pub fn trace(self) -> Result { let device = self.device.try_gpu()?; let executable = self.create_executable(&device)?; let index = executable.dispatch_trace(device).unwrap(); @@ -819,36 +822,7 @@ impl Tensor { let Executable { trace_list, .. } = executable; let mut result = trace_list.iter().map(|t| (*t).clone()).collect::>(); result.push(self); - Ok(result) - } - - /// # Serialization - /// - /// We may want to serialize a trace to disk to determine platform discrepancies. - /// - /// This method does the following: - /// 1. Creates a trace directory with a UUID, time, and device details - /// 2. Serializes each tensor in the trace to disk, with the name being the tensor ID - #[cfg(feature = "trace")] - pub fn serialize_trace(trace: Vec) -> Result<(), TensorError> { - use half::f16; - use web_time::Instant; - let trace_dir = format!( - "trace-{}-{:?}", - uuid::Uuid::new_v4().to_string(), - Instant::now() - ); - std::fs::create_dir(&trace_dir).map_err(|e| TensorError::IoError(e))?; - for t in trace.iter() { - let id = t.id(); - let path = format!("ratchet-{}/{}.npy", trace_dir, id); - match t.dt() { - DType::F16 => t.write_npy::(&path), - DType::F32 => t.write_npy::(&path), - _ => unimplemented!(), - }; - } - Ok(()) + Ok(Trace::new(result)) } fn to_gpu(&self, dst_device: &Device) -> Result { diff --git a/crates/ratchet-core/src/trace.rs b/crates/ratchet-core/src/trace.rs new file mode 100644 index 00000000..1e46fd2d --- /dev/null +++ b/crates/ratchet-core/src/trace.rs @@ -0,0 +1,54 @@ +#![cfg(feature = "trace")] +use crate::{DType, Tensor}; + +/// # Trace +/// +/// All intermediate products & result of a computation. +pub struct Trace(Vec); + +impl Trace { + pub fn new(t: Vec) -> Self { + Self(t) + } + + pub fn iter(&self) -> impl Iterator { + self.0.iter() + } + + pub fn iter_mut(&mut self) -> impl Iterator { + self.0.iter_mut() + } + + /// # Serialization + /// + /// We may want to serialize a trace to disk to determine platform discrepancies. + /// + /// This method does the following: + /// 1. Creates a trace directory with a UUID, time, and device details + /// 2. Serializes each tensor in the trace to disk, with the name being the tensor ID + pub fn serialize(&self) -> Result<(), anyhow::Error> { + log::warn!("Serializing trace to disk"); + use half::f16; + + let trace_dir = format!("trace-{}", uuid::Uuid::new_v4().to_string(),); + std::fs::create_dir(&trace_dir).map_err(|e| anyhow::anyhow!(e))?; + for t in self.iter() { + let id = t.id(); + let path = format!("{}/ratchet-{}.npy", trace_dir, id); + let _ = match t.dt() { + DType::F16 => t.write_npy::(&path), + DType::F32 => t.write_npy::(&path), + _ => unimplemented!(), + }; + } + Ok(()) + } +} + +impl Iterator for Trace { + type Item = Tensor; + + fn next(&mut self) -> Option { + self.0.pop() + } +} diff --git a/crates/ratchet-models/src/whisper/decoder.rs b/crates/ratchet-models/src/whisper/decoder.rs index b63f4457..657a1f3d 100644 --- a/crates/ratchet-models/src/whisper/decoder.rs +++ b/crates/ratchet-models/src/whisper/decoder.rs @@ -307,13 +307,12 @@ def ground(options): log_init(); let api = Api::new().unwrap(); let model = api.model("FL33TW00D-HF/whisper-tiny".to_string()); - let path = model.get("tiny_f32.gguf").unwrap(); - let config_path = model.get("config.json").unwrap(); - let config: Config = serde_json::from_slice(&std::fs::read(config_path).unwrap()).unwrap(); + let path = model.get("tiny_f32.gguf")?; + let config_path = model.get("config.json")?; + let config: Config = serde_json::from_slice(&std::fs::read(config_path)?)?; println!("MODEL LOADED FROM: {}", path.display()); let dataset = api.dataset("FL33TW00D-HF/ratchet-util".to_string()); - let options = DecodingOptionsBuilder::new().build(); let hs_npy = dataset.get("jfk_tiny_encoder_hs.npy").unwrap(); let audio_path = dataset.get("jfk.wav").unwrap(); @@ -338,30 +337,36 @@ def ground(options): while tokens[tokens.len() - 1] != 50257 { let token_t = Tensor::from_data(tokens.clone(), shape![1, tokens.len()], device.clone()); - let result = decoder.schedule([audio_ctx.clone(), token_t])?.trace()?; - - let our_logits = result.to(&Device::CPU)?; - let nd_logits = our_logits.to_ndarray_view::(); - println!("ND LOGITS: {:?}", nd_logits); - all_logits.push(Tensor::from( - nd_logits - .slice(s![.., .., ..tokenizer.get_vocab_size(true)]) - .to_owned() - .into_dyn(), - )); - - let sliced = nd_logits - .slice(s![.., -1.., ..tokenizer.get_vocab_size(true)]) - .remove_axis(Axis(1)); - decoder.cache_mut().update(tokens.len()); - - tokens = sliced - .map_axis(Axis(1), |row| row.argmax_skipnan().unwrap()) - .iter() - .map(|&x| x as i32) - .collect::>(); - println!("Token: {:?}", tokens); - all_tokens.extend(tokens.clone()); + let mut trace = decoder.schedule([audio_ctx.clone(), token_t])?.trace()?; + trace + .iter_mut() + .for_each(|t| *t = t.to(&Device::CPU).unwrap()); + + let _ = trace.serialize()?; + panic!("DONE"); + + //let our_logits = result.to(&Device::CPU)?; + //let nd_logits = our_logits.to_ndarray_view::(); + //println!("ND LOGITS: {:?}", nd_logits); + //all_logits.push(Tensor::from( + // nd_logits + // .slice(s![.., .., ..tokenizer.get_vocab_size(true)]) + // .to_owned() + // .into_dyn(), + //)); + + //let sliced = nd_logits + // .slice(s![.., -1.., ..tokenizer.get_vocab_size(true)]) + // .remove_axis(Axis(1)); + //decoder.cache_mut().update(tokens.len()); + + //tokens = sliced + // .map_axis(Axis(1), |row| row.argmax_skipnan().unwrap()) + // .iter() + // .map(|&x| x as i32) + // .collect::>(); + //println!("Token: {:?}", tokens); + //all_tokens.extend(tokens.clone()); } println!("Took: {:?}", start.elapsed()); @@ -370,6 +375,7 @@ def ground(options): println!("All tokens: {:?}", all_tokens); println!("Decoded: {}", decoded); + let options = DecodingOptionsBuilder::new().build(); let ground_logits = ground_truth(&audio_path.to_string_lossy(), options)?; let all_equal = all_logits .iter() diff --git a/crates/ratchet-models/tests/whisper.rs b/crates/ratchet-models/tests/whisper.rs index 539a520e..897c2442 100644 --- a/crates/ratchet-models/tests/whisper.rs +++ b/crates/ratchet-models/tests/whisper.rs @@ -52,7 +52,7 @@ async fn tiny_encoder() -> Result<(), JsValue> { async fn tiny_decoder() -> Result<(), JsValue> { log_init(); let model_repo = ApiBuilder::from_hf("FL33TW00D-HF/whisper-tiny", RepoType::Model).build(); - let model_data = model_repo.get("tiny_q8_0.gguf").await?; + let model_data = model_repo.get("tiny_f32.gguf").await?; let config_data = model_repo.get("config.json").await?; let ground_repo = ApiBuilder::from_hf("FL33TW00D-HF/ratchet-util", RepoType::Dataset).build(); From d41f029afd1317be9f4111cfd2bbd44f3be23221 Mon Sep 17 00:00:00 2001 From: FL33TW00D Date: Tue, 23 Jul 2024 15:28:02 +0100 Subject: [PATCH 5/6] chore: descent into madness --- Cargo.toml | 1 + crates/ratchet-core/Cargo.toml | 8 +-- crates/ratchet-core/src/device.rs | 7 ++ crates/ratchet-core/src/gpu/device.rs | 37 ++++++++++ crates/ratchet-core/src/tensor.rs | 6 +- crates/ratchet-core/src/trace.rs | 65 +++++++++++++++-- crates/ratchet-models/Cargo.toml | 1 + crates/ratchet-models/src/whisper/decoder.rs | 4 +- crates/ratchet-models/src/whisper/encoder.rs | 14 +++- crates/ratchet-models/tests/whisper.rs | 73 ++++++++++++++------ 10 files changed, 178 insertions(+), 38 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e8dc228d..4d8468c4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -93,3 +93,4 @@ wasm-bindgen-futures = "0.4.42" web-sys = "0.3.69" web-time = "1.0.0" futures-intrusive = "0.5.0" +include_dir = "0.7.4" diff --git a/crates/ratchet-core/Cargo.toml b/crates/ratchet-core/Cargo.toml index 4edc5555..cf55c361 100644 --- a/crates/ratchet-core/Cargo.toml +++ b/crates/ratchet-core/Cargo.toml @@ -34,11 +34,10 @@ log = { workspace = true } thiserror = { workspace = true } serde = { workspace = true, features = ["derive"] } anyhow.workspace = true - +smallvec = { workspace = true , features = ["serde"] } rustc-hash = { workspace = true } slotmap = { workspace = true } parking_lot = { workspace = true } -smallvec = { workspace = true } encase = { workspace = true, features = ["smallvec", "glam"] } pollster = { workspace = true } getrandom = { workspace = true, features = ["js"] } # Needed for wasm support in `num` trait @@ -64,15 +63,16 @@ numpy = { workspace = true, optional = true, features=["half"]} # Trace uuid = { version="1.1.0", features = ["v4", "fast-rng"], optional = true } web-time = { workspace = true, optional = true } +serde_json.workspace = true [target.'cfg(target_arch = "wasm32")'.dependencies] wasm-bindgen.workspace = true futures-intrusive.workspace = true wasm-bindgen-futures.workspace = true - async-trait = "0.1.77" -smallvec = { workspace = true , features = ["serde"] } +include_dir.workspace = true +smallvec = { workspace = true } [dev-dependencies] env_logger = { workspace = true } diff --git a/crates/ratchet-core/src/device.rs b/crates/ratchet-core/src/device.rs index 038578a4..dcd2ace1 100644 --- a/crates/ratchet-core/src/device.rs +++ b/crates/ratchet-core/src/device.rs @@ -81,6 +81,13 @@ impl Device { format!("{:?}", self) } + pub fn device_identifier(&self) -> String { + match self { + Device::CPU => "CPU".to_string(), + Device::GPU(gpu) => gpu.info().device_identifier(), + } + } + pub fn try_gpu(&self) -> Result<&WgpuDevice, DeviceError> { match self { Device::GPU(gpu) => Ok(gpu), diff --git a/crates/ratchet-core/src/gpu/device.rs b/crates/ratchet-core/src/gpu/device.rs index 3dc053ce..47143edc 100644 --- a/crates/ratchet-core/src/gpu/device.rs +++ b/crates/ratchet-core/src/gpu/device.rs @@ -23,6 +23,7 @@ pub struct WgpuDevice { pipeline_layout_pool: Arc, compute_pipeline_pool: Arc, kernel_module_pool: Arc, + device_info: DeviceInfo, device_limits: DeviceLimits, device_features: DeviceFeatures, device: Arc, @@ -112,6 +113,7 @@ impl WgpuDevice { kernel_module_pool: Arc::new(KernelModulePool::new()), compute_pipeline_pool: Arc::new(ComputePipelinePool::new()), device: Arc::new(device), + device_info: adapter.get_info().into(), device_limits: limits, device_features: features, }) @@ -167,6 +169,10 @@ impl WgpuDevice { pub fn limits(&self) -> &DeviceLimits { &self.device_limits } + + pub fn info(&self) -> &DeviceInfo { + &self.device_info + } } impl WgpuDevice { @@ -284,6 +290,37 @@ impl WgpuDevice { } } +#[derive(Clone)] +pub struct DeviceInfo { + pub name: String, + pub vendor: u32, + pub device: u32, + pub device_type: wgpu::DeviceType, + pub driver: String, + pub driver_info: String, + pub backend: wgpu::Backend, +} + +impl DeviceInfo { + pub fn device_identifier(&self) -> String { + format!("{}-{}", self.name.replace(" ", "-"), self.backend.to_str()) + } +} + +impl From for DeviceInfo { + fn from(info: wgpu::AdapterInfo) -> Self { + DeviceInfo { + name: info.name, + vendor: info.vendor, + device: info.device, + device_type: info.device_type, + driver: info.driver, + driver_info: info.driver_info, + backend: info.backend, + } + } +} + #[derive(Clone)] pub struct DeviceLimits { pub max_bind_groups: u32, diff --git a/crates/ratchet-core/src/tensor.rs b/crates/ratchet-core/src/tensor.rs index 7bbeb1f1..e50e5262 100644 --- a/crates/ratchet-core/src/tensor.rs +++ b/crates/ratchet-core/src/tensor.rs @@ -821,7 +821,11 @@ impl Tensor { let Executable { trace_list, .. } = executable; let mut result = trace_list.iter().map(|t| (*t).clone()).collect::>(); - result.push(self); + + for bingo in &result { + log::warn!("TENSOR ID: {:?}", bingo.id()); + } + Ok(Trace::new(result)) } diff --git a/crates/ratchet-core/src/trace.rs b/crates/ratchet-core/src/trace.rs index 1e46fd2d..a2e9817e 100644 --- a/crates/ratchet-core/src/trace.rs +++ b/crates/ratchet-core/src/trace.rs @@ -1,5 +1,5 @@ #![cfg(feature = "trace")] -use crate::{DType, Tensor}; +use crate::{DType, Device, Tensor}; /// # Trace /// @@ -26,23 +26,52 @@ impl Trace { /// This method does the following: /// 1. Creates a trace directory with a UUID, time, and device details /// 2. Serializes each tensor in the trace to disk, with the name being the tensor ID - pub fn serialize(&self) -> Result<(), anyhow::Error> { + pub fn serialize(&self, device: &Device) -> Result<(), anyhow::Error> { log::warn!("Serializing trace to disk"); - use half::f16; + let device_identifier = device.device_identifier(); - let trace_dir = format!("trace-{}", uuid::Uuid::new_v4().to_string(),); + let trace_dir = format!( + "trace-{}-{}", + uuid::Uuid::new_v4().to_string(), + device_identifier + ); std::fs::create_dir(&trace_dir).map_err(|e| anyhow::anyhow!(e))?; + + let mut metadata = Vec::new(); for t in self.iter() { - let id = t.id(); - let path = format!("{}/ratchet-{}.npy", trace_dir, id); + let name = format!("ratchet-{}.npy", t.id()); + let path = format!("{}/{}", trace_dir, name); let _ = match t.dt() { - DType::F16 => t.write_npy::(&path), DType::F32 => t.write_npy::(&path), _ => unimplemented!(), }; + metadata.push(name); + } + + let metadata_path = format!("{}/metadata.json", trace_dir); + std::fs::write(metadata_path, serde_json::to_string(&metadata)?)?; + Ok(()) + } + + pub fn compare(&self, other: &Self, atol: f32, rtol: f32) -> Result<(), anyhow::Error> { + assert_eq!(self.0.len(), other.0.len()); + log::warn!("Comparing traces"); + for (a, b) in self.iter().zip(other.iter()) { + log::warn!("A: {:?}", a); + log::warn!("B: {:?}", b); + log::warn!("Comparing tensor {} to tensor {}", a.id(), b.id()); + a.all_close(b, atol, rtol)?; } Ok(()) } + + pub fn first(&self) -> Option<&Tensor> { + self.0.first() + } + + pub fn pop(&mut self) -> Option { + self.0.pop() + } } impl Iterator for Trace { @@ -52,3 +81,25 @@ impl Iterator for Trace { self.0.pop() } } + +#[cfg(target_arch = "wasm32")] +impl Trace { + pub fn deserialize(dir: include_dir::Dir) -> Result { + let metadata_file = dir + .get_file("metadata.json") + .ok_or_else(|| anyhow::anyhow!("Metadata file not found"))?; + let metadata: Vec = serde_json::from_slice(metadata_file.contents())?; + + let mut tensors = Vec::with_capacity(metadata.len()); + + for filename in metadata { + let file = dir + .get_file(&filename) + .ok_or_else(|| anyhow::anyhow!("File not found: {}", filename))?; + let tensor = Tensor::from_npy_bytes::(file.contents(), &Device::CPU)?; + tensors.push(tensor); + } + + Ok(Self::new(tensors)) + } +} diff --git a/crates/ratchet-models/Cargo.toml b/crates/ratchet-models/Cargo.toml index 8e4a5545..c609ec46 100644 --- a/crates/ratchet-models/Cargo.toml +++ b/crates/ratchet-models/Cargo.toml @@ -65,6 +65,7 @@ wasm-bindgen-futures = { workspace = true } npyz = { workspace = true } hound = { workspace = true } env_logger = { workspace = true } +include_dir.workspace = true [target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies] ratchet = { path = "../ratchet-core", features = ["pyo3"] } diff --git a/crates/ratchet-models/src/whisper/decoder.rs b/crates/ratchet-models/src/whisper/decoder.rs index 657a1f3d..d887fe47 100644 --- a/crates/ratchet-models/src/whisper/decoder.rs +++ b/crates/ratchet-models/src/whisper/decoder.rs @@ -337,12 +337,14 @@ def ground(options): while tokens[tokens.len() - 1] != 50257 { let token_t = Tensor::from_data(tokens.clone(), shape![1, tokens.len()], device.clone()); + log::warn!("AUDIO: {:?}", audio_ctx); + log::warn!("TOKENS: {:?}", token_t); let mut trace = decoder.schedule([audio_ctx.clone(), token_t])?.trace()?; trace .iter_mut() .for_each(|t| *t = t.to(&Device::CPU).unwrap()); - let _ = trace.serialize()?; + let _ = trace.serialize(&device)?; panic!("DONE"); //let our_logits = result.to(&Device::CPU)?; diff --git a/crates/ratchet-models/src/whisper/encoder.rs b/crates/ratchet-models/src/whisper/encoder.rs index 482ffe3f..106f8fc9 100644 --- a/crates/ratchet-models/src/whisper/encoder.rs +++ b/crates/ratchet-models/src/whisper/encoder.rs @@ -238,8 +238,18 @@ mod tests { let encoder = WhisperEncoder::load(&header, &config, &mut reader, &device)?; let input = Tensor::read_npy::(input_npy, &device)?; - let result = encoder.schedule(input)?.full()?.resolve()?; - let ours = result.to(&Device::CPU)?; + let mut trace = encoder.schedule(input)?.full()?.trace()?; + + trace + .iter_mut() + .for_each(|t| *t = t.to(&Device::CPU).unwrap()); + + trace.serialize(&device); + //let ours = result.to(&Device::CPU)?; + + log::warn!("TRACE FIRST: {:#?}", trace.first().unwrap()); + let ours = trace.pop().unwrap(); + log::warn!("RESULT: {:#?}", ours); let ground = Tensor::read_npy::(ground_npy, &Device::CPU)?; println!("OURS: {:#?}", ours); println!("Ground: {:#?}", ground); diff --git a/crates/ratchet-models/tests/whisper.rs b/crates/ratchet-models/tests/whisper.rs index 897c2442..a2a44160 100644 --- a/crates/ratchet-models/tests/whisper.rs +++ b/crates/ratchet-models/tests/whisper.rs @@ -18,9 +18,10 @@ fn log_init() { console_log::init_with_level(log::Level::Debug).unwrap(); } -/* #[wasm_bindgen_test] async fn tiny_encoder() -> Result<(), JsValue> { + use include_dir::include_dir; + use ratchet::Trace; log_init(); let model_repo = ApiBuilder::from_hf("FL33TW00D-HF/whisper-tiny", RepoType::Model).build(); let model_data = model_repo.get("tiny_f32.gguf").await?; @@ -41,13 +42,27 @@ async fn tiny_encoder() -> Result<(), JsValue> { let ground = Tensor::from_npy_bytes::(&ground_npy.to_vec(), &Device::CPU).unwrap(); let encoder = WhisperEncoder::load(&header, &config, &mut reader, &device).unwrap(); - let result = encoder.schedule(input).unwrap().resolve().unwrap(); + let mut web_trace = encoder.schedule(input).unwrap().trace().unwrap(); + + for t in web_trace.iter_mut() { + *t = t.to(&Device::CPU).await.unwrap(); + } + + log::warn!("TRACE FIRST: {:?}", web_trace.first().unwrap()); + let result = web_trace.pop().unwrap(); + + log::warn!("TRACE RESULT: {:?}", result); + + //let trace = + // Trace::deserialize(include_dir!("/Users/fleetwood/Code/ratchet/crates/ratchet-models/trace-221df6bf-82be-419e-bf79-b00da2964c95-Apple-M3-Max-metal")).unwrap(); + //trace.compare(&web_trace, 1e-3, 1e-3).unwrap(); + let ours = result.to(&Device::CPU).await.unwrap(); ground.all_close(&ours, 1e-3, 1e-3).unwrap(); Ok(()) } -*/ +/* #[wasm_bindgen_test] async fn tiny_decoder() -> Result<(), JsValue> { log_init(); @@ -64,40 +79,51 @@ async fn tiny_decoder() -> Result<(), JsValue> { let device = Device::request_device(DeviceRequest::GPU).await.unwrap(); - let audio_ctx = Tensor::from_npy_bytes::(&hs_data.to_vec(), &device) + let mut audio_ctx = Tensor::from_npy_bytes::(&hs_data.to_vec(), &device) .unwrap() .cast(device.compute_precision()) .unwrap(); + let mut decoder = WhisperDecoder::load(&header, &config, &mut reader, &device).unwrap(); let mut tokens = vec![50258, 50259, 50359]; let mut all_tokens = tokens.clone(); - let mut all_logits = vec![]; + //let mut all_logits = vec![]; let vocab_size = 51865; while tokens[tokens.len() - 1] != 50257 { let token_t = Tensor::from_data(tokens.clone(), shape![1, tokens.len()], device.clone()); - let result = decoder + let mut web_trace = decoder .schedule([audio_ctx.clone(), token_t]) .unwrap() - .resolve() + .trace() .unwrap(); - let our_logits = result.to(&Device::CPU).await.unwrap(); - all_logits.push(our_logits.clone()); - let nd_logits = our_logits.to_ndarray_view::(); - log::debug!("ND LOGITS: {:?}", nd_logits); - let sliced = nd_logits - .slice(s![.., -1.., ..vocab_size]) - .remove_axis(Axis(1)); - decoder.cache_mut().update(tokens.len()); - - tokens = sliced - .map_axis(Axis(1), |row| row.argmax_skipnan().unwrap()) - .iter() - .map(|&x| x as i32) - .collect::>(); - println!("Token: {:?}", tokens); - all_tokens.extend(tokens.clone()); + for t in web_trace.iter_mut() { + *t = t.to(&Device::CPU).await.unwrap(); + } + + let trace = + Trace::deserialize(include_dir!("/Users/fleetwood/Code/ratchet/crates/ratchet-models/trace-f79dfd98-fd7b-4dbf-96a8-e4552021aacc-Apple-M3-Max-metal")).unwrap(); + trace.compare(&web_trace, 1e-3, 1e-3).unwrap(); + + panic!(""); + + //let our_logits = result.to(&Device::CPU).await.unwrap(); + //all_logits.push(our_logits.clone()); + //let nd_logits = our_logits.to_ndarray_view::(); + //log::debug!("ND LOGITS: {:?}", nd_logits); + //let sliced = nd_logits + // .slice(s![.., -1.., ..vocab_size]) + // .remove_axis(Axis(1)); + //decoder.cache_mut().update(tokens.len()); + + //tokens = sliced + // .map_axis(Axis(1), |row| row.argmax_skipnan().unwrap()) + // .iter() + // .map(|&x| x as i32) + // .collect::>(); + //println!("Token: {:?}", tokens); + //all_tokens.extend(tokens.clone()); } let ground_tokens = vec![ @@ -107,3 +133,4 @@ async fn tiny_decoder() -> Result<(), JsValue> { assert_eq!(all_tokens, ground_tokens); Ok(()) } +*/ From a95409b8376adf0c8050e6b9d4333cafa060e1d1 Mon Sep 17 00:00:00 2001 From: FL33TW00D Date: Tue, 23 Jul 2024 16:51:54 +0100 Subject: [PATCH 6/6] chore: ridiculous sync issues between local and browser --- crates/ratchet-core/src/executable.rs | 4 +- crates/ratchet-core/src/gpu/device.rs | 1 + crates/ratchet-core/src/tensor.rs | 6 +- crates/ratchet-core/src/trace.rs | 1 - crates/ratchet-models/src/whisper/encoder.rs | 61 +++++++++++++++- crates/ratchet-models/src/whisper/mod.rs | 1 + .../src/whisper/residual_block.rs | 25 +++---- crates/ratchet-models/tests/whisper.rs | 69 ++++++++++++++++++- 8 files changed, 144 insertions(+), 24 deletions(-) diff --git a/crates/ratchet-core/src/executable.rs b/crates/ratchet-core/src/executable.rs index 85edbab8..d4d27776 100644 --- a/crates/ratchet-core/src/executable.rs +++ b/crates/ratchet-core/src/executable.rs @@ -103,13 +103,13 @@ impl Executable<'_> { .map_err(|_| ExecutionError::DebuggingError("Failed to get result buf."))? .inner; - let debug_buffer = step + let trace_buffer = step .trace_buffer .as_ref() .ok_or(ExecutionError::DebuggingError( "Failed to get debug buffer.", ))?; - encoder.copy_buffer_to_buffer(result_buf, 0, debug_buffer, 0, debug_buffer.size()); + encoder.copy_buffer_to_buffer(result_buf, 0, trace_buffer, 0, trace_buffer.size()); let index = device.queue().submit(Some(encoder.finish())); last_index = Some(index); diff --git a/crates/ratchet-core/src/gpu/device.rs b/crates/ratchet-core/src/gpu/device.rs index 47143edc..68dcf207 100644 --- a/crates/ratchet-core/src/gpu/device.rs +++ b/crates/ratchet-core/src/gpu/device.rs @@ -77,6 +77,7 @@ impl WgpuDevice { max_compute_invocations_per_workgroup: 1024, ..Default::default() }, + memory_hints: wgpu::MemoryHints::default(), }; let device_request = adapter.request_device(&device_descriptor, None).await; let (device, queue) = if let Err(e) = device_request { diff --git a/crates/ratchet-core/src/tensor.rs b/crates/ratchet-core/src/tensor.rs index e50e5262..e98718b6 100644 --- a/crates/ratchet-core/src/tensor.rs +++ b/crates/ratchet-core/src/tensor.rs @@ -820,11 +820,7 @@ impl Tensor { device.poll(wgpu::MaintainBase::WaitForSubmissionIndex(index)); let Executable { trace_list, .. } = executable; - let mut result = trace_list.iter().map(|t| (*t).clone()).collect::>(); - - for bingo in &result { - log::warn!("TENSOR ID: {:?}", bingo.id()); - } + let result = trace_list.iter().map(|t| (*t).clone()).collect::>(); Ok(Trace::new(result)) } diff --git a/crates/ratchet-core/src/trace.rs b/crates/ratchet-core/src/trace.rs index a2e9817e..fec0c613 100644 --- a/crates/ratchet-core/src/trace.rs +++ b/crates/ratchet-core/src/trace.rs @@ -59,7 +59,6 @@ impl Trace { for (a, b) in self.iter().zip(other.iter()) { log::warn!("A: {:?}", a); log::warn!("B: {:?}", b); - log::warn!("Comparing tensor {} to tensor {}", a.id(), b.id()); a.all_close(b, atol, rtol)?; } Ok(()) diff --git a/crates/ratchet-models/src/whisper/encoder.rs b/crates/ratchet-models/src/whisper/encoder.rs index 106f8fc9..006624ab 100644 --- a/crates/ratchet-models/src/whisper/encoder.rs +++ b/crates/ratchet-models/src/whisper/encoder.rs @@ -97,7 +97,7 @@ impl EncoderStem { #[derive(Debug)] pub struct WhisperEncoder { stem: EncoderStem, - blocks: Vec, + pub blocks: Vec, ln_post: LayerNorm, activation_dt: DType, } @@ -207,16 +207,71 @@ impl WhisperEncoder { #[cfg(all(test, not(target_arch = "wasm32")))] mod tests { use hf_hub::api::sync::Api; - use ratchet::{Device, DeviceRequest, Tensor}; + use ratchet::{prelude::shape, Device, DeviceRequest, Tensor}; use ratchet_loader::gguf::gguf; use ratchet_nn::Module; - use crate::whisper::{config::Config, encoder::WhisperEncoder}; + use crate::whisper::{ + config::Config, encoder::WhisperEncoder, residual_block::ResidualAttentionBlockInputs, + }; fn log_init() { let _ = env_logger::builder().is_test(true).try_init(); } + #[test] + fn tiny_test() -> anyhow::Result<()> { + log_init(); + let api = Api::new().unwrap(); + let model = api.model("FL33TW00D-HF/whisper-tiny".to_string()); + let model_path = model.get("tiny_f32.gguf").unwrap(); + println!("Path: {}", model_path.display()); + let config_path = model.get("config.json").unwrap(); + + let dataset = api.dataset("FL33TW00D-HF/ratchet-util".to_string()); + let input_npy = dataset.get("jfk_tiny_encoder_input.npy").unwrap(); + let ground_npy = dataset.get("jfk_tiny_encoder_hs.npy").unwrap(); + + let mut reader = std::io::BufReader::new(std::fs::File::open(model_path).unwrap()); + let header = gguf::Header::read(&mut reader).unwrap(); + let config: Config = serde_json::from_slice(&std::fs::read(config_path).unwrap()).unwrap(); + let device = Device::request_device(DeviceRequest::GPU).unwrap(); + + let encoder = WhisperEncoder::load(&header, &config, &mut reader, &device)?; + //let input = Tensor::read_npy::(input_npy, &device)?; + + let x = Tensor::from_data( + vec![5_f32; 1 * 1500 * 384], + shape![1, 1500, 384], + device.clone(), + ); + + let mut trace = encoder.blocks[0] + .schedule(ResidualAttentionBlockInputs { + x: x.clone(), + xa: None, + mask: None, + cache: None, + })? + .full()? + .trace()?; + + trace.iter_mut().for_each(|t| { + *t = t.to(&Device::CPU).unwrap(); + log::warn!("TRACE: {:?}", t); + }); + + trace.serialize(&device); + //let ours = result.to(&Device::CPU)?; + + //let ground = Tensor::read_npy::(ground_npy, &Device::CPU)?; + //println!("OURS: {:#?}", ours); + //println!("Ground: {:#?}", ground); + //ground.all_close(&ours, 1e-3, 1e-3)?; + + Ok(()) + } + #[test] fn encoder_matches() -> anyhow::Result<()> { log_init(); diff --git a/crates/ratchet-models/src/whisper/mod.rs b/crates/ratchet-models/src/whisper/mod.rs index bba35c1e..93061077 100644 --- a/crates/ratchet-models/src/whisper/mod.rs +++ b/crates/ratchet-models/src/whisper/mod.rs @@ -19,3 +19,4 @@ pub use config::Config; pub use decoder::WhisperDecoder; pub use encoder::WhisperEncoder; pub use model::Whisper; +pub use residual_block::ResidualAttentionBlockInputs; diff --git a/crates/ratchet-models/src/whisper/residual_block.rs b/crates/ratchet-models/src/whisper/residual_block.rs index 97d0fde3..c7757ceb 100644 --- a/crates/ratchet-models/src/whisper/residual_block.rs +++ b/crates/ratchet-models/src/whisper/residual_block.rs @@ -35,19 +35,20 @@ impl Module for ResidualAttentionBlock { self.attn .schedule(MHAInputs::new(attn_ln, None, mask.clone(), cache, true))?; - let mut attn = x.add(self_attn)?; + //let mut attn = x.add(self_attn)?; - if let Some(ref xa_blck) = self.x_attn { - if let Some(xa_ln) = &self.x_attn_ln { - let x_attn_ln = xa_ln.schedule(attn.clone())?; - let x_attn = - xa_blck.schedule(MHAInputs::new(x_attn_ln, xa.clone(), None, None, false))?; - attn = x_attn.add(attn.clone())?; - } - } - let mlp_ln = self.mlp_ln.schedule(attn.clone())?; - let mlp = self.mlp.schedule(mlp_ln)?; - mlp.add(attn) + //if let Some(ref xa_blck) = self.x_attn { + // if let Some(xa_ln) = &self.x_attn_ln { + // let x_attn_ln = xa_ln.schedule(attn.clone())?; + // let x_attn = + // xa_blck.schedule(MHAInputs::new(x_attn_ln, xa.clone(), None, None, false))?; + // attn = x_attn.add(attn.clone())?; + // } + //} + //let mlp_ln = self.mlp_ln.schedule(attn.clone())?; + //let mlp = self.mlp.schedule(mlp_ln)?; + //mlp.add(attn) + Ok(self_attn) } } diff --git a/crates/ratchet-models/tests/whisper.rs b/crates/ratchet-models/tests/whisper.rs index a2a44160..290ec3a0 100644 --- a/crates/ratchet-models/tests/whisper.rs +++ b/crates/ratchet-models/tests/whisper.rs @@ -6,7 +6,9 @@ use ndarray_stats::QuantileExt; use ratchet::{shape, Device, DeviceRequest, Tensor}; use ratchet_hub::{ApiBuilder, RepoType}; use ratchet_loader::gguf::gguf; -use ratchet_models::whisper::{Config, Whisper, WhisperDecoder, WhisperEncoder}; +use ratchet_models::whisper::{ + Config, ResidualAttentionBlockInputs, Whisper, WhisperDecoder, WhisperEncoder, +}; use ratchet_nn::Module; use std::path::PathBuf; use wasm_bindgen::prelude::*; @@ -18,6 +20,70 @@ fn log_init() { console_log::init_with_level(log::Level::Debug).unwrap(); } +#[wasm_bindgen_test] +async fn tiny_test() -> Result<(), JsValue> { + use include_dir::include_dir; + use ratchet::Trace; + log_init(); + let model_repo = ApiBuilder::from_hf("FL33TW00D-HF/whisper-tiny", RepoType::Model).build(); + let model_data = model_repo.get("tiny_f32.gguf").await?; + let config_data = model_repo.get("config.json").await?; + + let ground_repo = ApiBuilder::from_hf("FL33TW00D-HF/ratchet-util", RepoType::Dataset).build(); + let input_npy = ground_repo.get("jfk_tiny_encoder_input.npy").await?; + let ground_npy = ground_repo.get("jfk_tiny_encoder_hs.npy").await?; + + let mut reader = std::io::BufReader::new(std::io::Cursor::new(model_data.to_vec())); + let header = gguf::Header::read(&mut reader).unwrap(); + let config: Config = serde_json::from_slice(&config_data.to_vec()).unwrap(); + + let device = Device::request_device(DeviceRequest::GPU).await.unwrap(); + + //let input_data = &input_npy.to_vec(); + //let input = Tensor::from_npy_bytes::(input_data, &device).unwrap(); + //let ground = Tensor::from_npy_bytes::(&ground_npy.to_vec(), &Device::CPU).unwrap(); + + let encoder = WhisperEncoder::load(&header, &config, &mut reader, &device).unwrap(); + + let x = Tensor::from_data( + vec![5_f32; 1 * 1500 * 384], + shape![1, 1500, 384], + device.clone(), + ); + + let mut web_trace = encoder.blocks[0] + .schedule(ResidualAttentionBlockInputs { + x: x.clone(), + xa: None, + mask: None, + cache: None, + }) + .unwrap() + .full() + .unwrap() + .trace() + .unwrap(); + + for t in web_trace.iter_mut() { + *t = t.to(&Device::CPU).await.unwrap(); + log::warn!("TRACE: {:?}", t); + } + + //log::warn!("TRACE FIRST: {:?}", web_trace.first().unwrap()); + //let result = web_trace.pop().unwrap(); + + //log::warn!("TRACE RESULT: {:?}", result); + + let trace = + Trace::deserialize(include_dir!("/Users/fleetwood/Code/ratchet/crates/ratchet-models/trace-d3429424-55e4-453d-9aa6-fcb761307d8e-Apple-M3-Max-metal")).unwrap(); + trace.compare(&web_trace, 1e-3, 1e-3).unwrap(); + + //let ours = result.to(&Device::CPU).await.unwrap(); + //ground.all_close(&ours, 1e-3, 1e-3).unwrap(); + Ok(()) +} + +/* #[wasm_bindgen_test] async fn tiny_encoder() -> Result<(), JsValue> { use include_dir::include_dir; @@ -61,6 +127,7 @@ async fn tiny_encoder() -> Result<(), JsValue> { ground.all_close(&ours, 1e-3, 1e-3).unwrap(); Ok(()) } +*/ /* #[wasm_bindgen_test]