Draft
37 commits
4009d62
Add support for BF16 w/ `candle` (WIP)
alvarobartt Jan 20, 2026
4a9400e
Add support for `voyageai/voyage-4-*` (WIP)
alvarobartt Jan 20, 2026
e9b9332
Remove `half` crate dependency
alvarobartt Jan 20, 2026
d4f66bf
Add `tracing::warn!` on `DType::Bfloat16`
alvarobartt Jan 21, 2026
db7df49
Update `channel` to 1.92 in `rust-toolchain.toml`
alvarobartt Jan 21, 2026
6b2a1f1
Constrain `Bfloat16` on `candle` for Metal and CUDA
alvarobartt Jan 31, 2026
cec5de8
Update `tracing::warn!` for `DType::Bfloat16`
alvarobartt Jan 31, 2026
8889979
Fix `DType` enum and impl (never default to BF16)
alvarobartt Jan 31, 2026
35dc4d7
Add minimum finite value for `DType::BF16`
alvarobartt Jan 31, 2026
d7962ed
Restore `qwen3.rs` prior Voyage AI related changes
alvarobartt Jan 31, 2026
92c6ae7
Use `actions-rust-lang` and fix typo in "linting"
alvarobartt Jan 31, 2026
3838c04
Fix formatting in YAML files
alvarobartt Jan 31, 2026
04f5d63
Set default minimum finite value to FP16 value
alvarobartt Jan 31, 2026
b1a1b23
Calculate distance in Gemma3 w/ `abs_diff` instead (clippy)
alvarobartt Jan 31, 2026
b547b0b
Fix feature-gating for `DType::Bfloat16` (and exclude Turing and Volta)
alvarobartt Jan 31, 2026
43a8e99
Add note on lack of BF16 support for Turing (and earlier)
alvarobartt Jan 31, 2026
66b6a7b
Use `into_iter` for `Sequence` as `get_pre_tokenizers` unavailable
alvarobartt Jan 31, 2026
07fe2e7
Add runtime validation on CUDA compute cap for BF16
alvarobartt Jan 31, 2026
e506c93
Update `rustc` to latest stable 1.92
alvarobartt Jan 31, 2026
e3e34b8
Merge branch 'main' into add-bfloat16-support
alvarobartt Jan 31, 2026
eb4f13b
Revert `Default` impl for `DType` when `feature = "python"`
alvarobartt Jan 31, 2026
9745b0e
Add missing `cuda` feature flag to `candle-cuda` and `candle-cuda-tur…
alvarobartt Feb 1, 2026
9387804
Add BF16 support for `FlashQwen3`
alvarobartt Feb 1, 2026
29fe799
Remove `feature = "python"` from `tracing::warn!` on BF16
alvarobartt Feb 1, 2026
df60701
Add `supports_flash_attn` to remove duplicated code
alvarobartt Feb 1, 2026
b83102a
Update `index_select` to exclude CUDA + BF16
alvarobartt Feb 2, 2026
f6880c7
Skip BF16 support for CUDA (only Metal)
alvarobartt Feb 2, 2026
c19087e
Fix `index_select` feature gating
alvarobartt Feb 2, 2026
1d97366
Merge branch 'main' into add-bfloat16-support
alvarobartt Mar 30, 2026
a5ba1e4
Merge branch 'main' into add-bfloat16-support
alvarobartt Apr 29, 2026
65080b3
Bring `candle` related progress from `update-candle-wo-linking`
alvarobartt Apr 29, 2026
2ae1832
Update `Cargo.lock` and fix build on CUDA
alvarobartt Apr 29, 2026
a7272be
Fix default `{dynamic,static}-linking`
alvarobartt Apr 29, 2026
7fddd78
Add missing `candle-flash-attn-v1`
alvarobartt Apr 29, 2026
8419116
Set `default-features` to false for `candle-*`
alvarobartt Apr 29, 2026
5bc43f0
Run `cargo update`
alvarobartt Apr 29, 2026
f9fa7a2
Fix `static-linking` in `Dockerfile-cuda` (WIP)
alvarobartt Apr 30, 2026
2,803 changes: 1,601 additions & 1,202 deletions Cargo.lock

Large diffs are not rendered by default.

33 changes: 17 additions & 16 deletions Cargo.toml
@@ -42,25 +42,26 @@ serde_json = "1.0"
thiserror = "1.0"
rand = "0.9"
serial_test = "2.0.0"
cudarc = { version = "0.13", features =["cuda-12020"], default-features = false }
cudarc = { version = "0.19", features = ["cuda-version-from-build-system"], default-features = false }
intel-mkl-src = { version = "0.8", default-features = false }
candle = { version = "0.8", package = "candle-core" }
candle-nn = { version = "0.8" }
candle-transformers = { version = "0.8" }
candle-flash-attn = { version = "0.8" }
candle-cublaslt = { version = "0.0.1" }
candle-layer-norm = { version = "0.0.1" }
candle-index-select-cu = { version = "0.0.1", features = ["cuda-11"], default-features = false }
candle-rotary = { version = "0.0.1" }
candle-flash-attn-v1 = { version = "0.0.1" }
half = { version = "2.3.1", features = ["num-traits"] }
candle = { version = "0.9.2", package = "candle-core" }
candle-nn = { version = "0.9.2" }
candle-transformers = { version = "0.9.2" }
candle-flash-attn = { version = "0.9.2" }
candle-cublaslt = { version = "0.0.1", default-features = false }
candle-layer-norm = { version = "0.0.1", default-features = false }
candle-rotary = { version = "0.0.1", default-features = false }
candle-flash-attn-v1 = { version = "0.0.1", default-features = false }

[patch.crates-io]
cudarc = { git = "https://github.com/Narsil/cudarc" , rev = "8b4f18b4bcd5e4b1a9daf40abc3a2e27f83f06e9"}
candle = { git = "https://github.com/huggingface/candle", rev = "6381023982251959a2c9bab7378b3013304e192b", package = "candle-core" }
candle-nn = { git = "https://github.com/huggingface/candle", rev = "6381023982251959a2c9bab7378b3013304e192b", package = "candle-nn" }
candle-transformers = { git = "https://github.com/huggingface/candle", rev = "6381023982251959a2c9bab7378b3013304e192b", package = "candle-transformers" }
candle-flash-attn = { git = "https://github.com/huggingface/candle", rev = "6381023982251959a2c9bab7378b3013304e192b", package = "candle-flash-attn" }
candle = { git = "https://github.com/huggingface/candle", branch = "no-default-linking", package = "candle-core" }
candle-nn = { git = "https://github.com/huggingface/candle", branch = "no-default-linking", package = "candle-nn" }
candle-transformers = { git = "https://github.com/huggingface/candle", branch = "no-default-linking", package = "candle-transformers" }
candle-flash-attn = { git = "https://github.com/huggingface/candle", branch = "no-default-linking", package = "candle-flash-attn" }
candle-cublaslt = { git = "https://github.com/huggingface/candle-extensions", branch = "allow-static-linking" }
candle-layer-norm = { git = "https://github.com/huggingface/candle-extensions", branch = "allow-static-linking" }
candle-rotary = { git = "https://github.com/huggingface/candle-extensions", branch = "allow-static-linking" }
candle-flash-attn-v1 = { git = "https://github.com/huggingface/candle-extensions", branch = "allow-static-linking" }

[profile.release]
debug = 0
5 changes: 4 additions & 1 deletion Dockerfile-cuda
@@ -8,9 +8,12 @@ ENV CARGO_CHEF=0.1.73

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
curl \
git \
libssl-dev \
libstdc++-13-dev \
pkg-config \
&& rm -rf /var/lib/apt/lists/*
&& rm -rf /var/lib/apt/lists/* \
&& ln -sf "$(gcc --print-file-name=libstdc++.a)" "/usr/lib/$(gcc -print-multiarch)/libstdc++.a"

# Download and configure sccache (multi-arch)
ARG TARGETARCH
2 changes: 2 additions & 0 deletions backends/Cargo.toml
@@ -29,3 +29,5 @@ mkl = ["text-embeddings-backend-candle?/mkl"]
accelerate = ["text-embeddings-backend-candle?/accelerate"]
flash-attn = ["text-embeddings-backend-candle?/flash-attn"]
flash-attn-v1 = ["text-embeddings-backend-candle?/flash-attn-v1"]
static-linking = ["text-embeddings-backend-candle?/static-linking"]
dynamic-linking = ["text-embeddings-backend-candle?/dynamic-linking"]
7 changes: 4 additions & 3 deletions backends/candle/Cargo.toml
@@ -15,7 +15,6 @@ candle-transformers = { workspace = true }
candle-flash-attn = { workspace = true, optional = true}
candle-flash-attn-v1 = { workspace = true, optional = true }
candle-cublaslt = { workspace = true, optional = true }
candle-index-select-cu = { workspace = true, optional = true, features = ["cuda-11"], default-features = false}
candle-layer-norm = { workspace = true, optional = true }
candle-rotary = { workspace = true, optional = true }
nohash-hasher = { workspace = true }
@@ -41,7 +40,9 @@ anyhow = { version = "1", features = ["backtrace"] }
[features]
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate"]
metal = ["candle/metal", "candle-nn/metal"]
mkl = ["dep:intel-mkl-src", "candle/_mkl"]
cuda = ["candle/_cuda", "candle-nn/_cuda", "dep:candle-cublaslt", "dep:candle-layer-norm", "dep:candle-rotary", "dep:candle-index-select-cu"]
mkl = ["dep:intel-mkl-src", "candle/mkl-enabled"]
cuda = ["candle/cuda-enabled", "candle-nn/cuda-enabled", "dep:candle-cublaslt", "dep:candle-layer-norm", "dep:candle-rotary"]
flash-attn-v1 = ["dep:candle-flash-attn-v1", "cuda"]
flash-attn = ["dep:candle-flash-attn", "cuda"]
static-linking = ["candle-cublaslt?/static-linking", "candle-layer-norm?/static-linking", "candle-rotary?/static-linking", "candle-flash-attn-v1?/static-linking"]
dynamic-linking = ["candle-cublaslt?/dynamic-linking", "candle-layer-norm?/dynamic-linking", "candle-rotary?/dynamic-linking", "candle-flash-attn-v1?/dynamic-linking"]
8 changes: 4 additions & 4 deletions backends/candle/src/compute_cap.rs
@@ -1,9 +1,9 @@
use anyhow::Context;

use candle::cuda_backend::cudarc::driver;
use candle::cuda_backend::cudarc::driver::sys::CUdevice_attribute::{
CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR,
};
use candle::cuda_backend::cudarc::driver::CudaDevice;

pub fn get_compile_compute_cap() -> Result<usize, anyhow::Error> {
env!("CUDA_COMPUTE_CAP")
@@ -13,11 +13,11 @@ pub fn get_compile_compute_cap() -> Result<usize, anyhow::Error> {

pub fn get_runtime_compute_cap() -> Result<usize, anyhow::Error> {
driver::result::init().context("CUDA is not available")?;
let device = CudaDevice::new(0).context("CUDA is not available")?;
let major = device
let context = driver::CudaContext::new(0).context("CUDA is not available")?;
let major = context
.attribute(CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR)
.context("Could not retrieve device compute capability major")?;
let minor = device
let minor = context
.attribute(CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR)
.context("Could not retrieve device compute capability minor")?;
Ok((major * 10 + minor) as usize)
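
The runtime compute-capability helper above is what the BF16 gating in `lib.rs` is meant to build on (the CUDA path there is currently commented out). A minimal sketch of such a gate, reusing `get_runtime_compute_cap` as defined above; the helper name `validate_bf16_compute_cap` is only illustrative:

```rust
use anyhow::bail;

// Sketch only: reject BF16 on pre-Ampere GPUs, reusing the
// `get_runtime_compute_cap()` helper from compute_cap.rs above.
fn validate_bf16_compute_cap() -> Result<(), anyhow::Error> {
    let compute_cap = get_runtime_compute_cap()?;
    // BF16 support starts at compute capability 8.0 (Ampere); Volta (7.0)
    // and Turing (7.5) only provide FP16 tensor cores.
    if compute_cap < 80 {
        bail!(
            "BFloat16 requires CUDA compute capability >= 8.0, found {}.{}",
            compute_cap / 10,
            compute_cap % 10
        );
    }
    Ok(())
}
```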
1 change: 1 addition & 0 deletions backends/candle/src/flash_attn.rs
@@ -3,6 +3,7 @@ use std::sync::Once;

static INIT: Once = Once::new();
static mut RUNTIME_COMPUTE_CAP: usize = 0;

fn init_runtime_compute_cap() {
unsafe {
INIT.call_once(|| {
14 changes: 1 addition & 13 deletions backends/candle/src/layers/index_select.rs
@@ -1,19 +1,7 @@
// SPDX-License-Identifier: MIT or Apache-2.0
// First Published under RadixMLP and https://github.com/michaelfeil/candle-index-select-cu by Michael Feil

use candle::{Result, Tensor};
#[cfg(feature = "cuda")]
use candle_index_select_cu;

#[inline]
#[allow(dead_code)]
pub fn index_select(tensor: &Tensor, ids: &Tensor, dim: usize) -> Result<Tensor> {
#[cfg(not(feature = "cuda"))]
{
tensor.index_select(ids, dim)
}
#[cfg(feature = "cuda")]
{
candle_index_select_cu::index_select(tensor, ids, dim)
}
tensor.index_select(ids, dim)
}
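
With the CUDA-specific `candle-index-select-cu` kernel dropped, the wrapper above is now a thin pass-through to candle's built-in `Tensor::index_select` on every device. A small usage sketch (shapes and values are illustrative):

```rust
use candle::{DType, Device, Result, Tensor};

// Illustrative only: gather rows 0, 3 and 7 from a lookup table; this is
// exactly what the `index_select` wrapper above delegates to.
fn gather_rows() -> Result<Tensor> {
    let table = Tensor::zeros((10, 4), DType::F32, &Device::Cpu)?;
    let ids = Tensor::new(&[0u32, 3, 7], &Device::Cpu)?;
    table.index_select(&ids, 0) // result shape: (3, 4)
}
```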
42 changes: 41 additions & 1 deletion backends/candle/src/lib.rs
@@ -261,11 +261,42 @@ impl CandleBackend {
}
.map_err(|err| BackendError::Start(err.to_string()))?;

// Get candle dtype
let dtype = if &dtype == "float32" {
Ok(DType::F32)
} else if &dtype == "float16" {
Ok(DType::F16)
} else if &dtype == "bfloat16" {
match &device {
Device::Cpu => {
return Err(BackendError::Start(
"BFloat16 is not supported on CPU. Use float16 or float32 instead."
.to_string(),
));
}
Device::Cuda(_) => {
return Err(BackendError::Start(
"CUDA feature is not enabled".to_string(),
));
}
// NOTE: Temporarily left out given that supporting BF16 w/ Flash Attn requires an
// update on `candle` and `candle-extensions` which is still in progress
// #[cfg(feature = "cuda")]
// Device::Cuda(_) => {
// let compute_cap = get_runtime_compute_cap().map_err(|e| {
// BackendError::Start(format!("Failed to get CUDA compute capability: {e:?}"))
// })?;
// if compute_cap < 80 {
// return Err(BackendError::Start(format!(
// "BFloat16 requires CUDA compute capability >= 8.0 (Ampere or newer), \
// but found {}.{}. Use float16 or float32 instead.",
// compute_cap / 10,
// compute_cap % 10
// )));
// }
// }
Device::Metal(_) => (),
}
Ok(DType::BF16)
} else {
Err(BackendError::Start(format!(
"DType {dtype} is not supported"
@@ -377,6 +408,7 @@ impl CandleBackend {
}
#[cfg(feature = "cuda")]
(Config::Bert(config), Device::Cuda(_)) => {
// TODO(alvarobartt): Include the `dtype` as an arg in `use_flash_attn`
if dtype == DType::F16 && use_flash_attn(&[FlashAttn::V1, FlashAttn::V2]) {
match config {
BertConfigWrapper::JinaBert(config) => {
@@ -420,6 +452,7 @@ impl CandleBackend {
Config::Camembert(config) | Config::Roberta(config) | Config::XlmRoberta(config),
Device::Cuda(_),
) => {
// TODO(alvarobartt): Include the `dtype` as an arg in `use_flash_attn`
if dtype == DType::F16 && use_flash_attn(&[FlashAttn::V1, FlashAttn::V2]) {
tracing::info!("Starting FlashBert model on {:?}", device);
Ok(Box::new(
@@ -439,6 +472,7 @@ impl CandleBackend {
}
#[cfg(feature = "cuda")]
(Config::DistilBert(config), Device::Cuda(_)) => {
// TODO(alvarobartt): Include the `dtype` as an arg in `use_flash_attn`
if dtype == DType::F16 && use_flash_attn(&[FlashAttn::V2]) {
tracing::info!("Starting FlashDistilBert model on {:?}", device);
Ok(Box::new(
@@ -465,6 +499,7 @@ impl CandleBackend {
}
#[cfg(feature = "cuda")]
(Config::Gte(config), Device::Cuda(_)) => {
// TODO(alvarobartt): Include the `dtype` as an arg in `use_flash_attn`
if dtype == DType::F16 && use_flash_attn(&[FlashAttn::V1, FlashAttn::V2]) {
tracing::info!("Starting FlashGTE model on {:?}", device);
Ok(Box::new(FlashGTEModel::load(vb, &config, model_type).s()?))
@@ -475,6 +510,7 @@ impl CandleBackend {
}
#[cfg(feature = "cuda")]
(Config::Mistral(config), Device::Cuda(_)) => {
// TODO(alvarobartt): Include the `dtype` as an arg in `use_flash_attn`
if !(dtype == DType::F16 && use_flash_attn(&[FlashAttn::V2])) {
return Err(BackendError::Start("Mistral is only supported on Cuda devices in fp16 with flash attention v2 enabled".to_string()));
}
@@ -485,6 +521,7 @@ impl CandleBackend {
}
#[cfg(feature = "cuda")]
(Config::ModernBert(config), Device::Cuda(_)) => {
// TODO(alvarobartt): Include the `dtype` as an arg in `use_flash_attn`
if dtype == DType::F16 && use_flash_attn(&[FlashAttn::V2]) {
tracing::info!("Starting FlashModernBert model on {:?}", device);
Ok(Box::new(
@@ -501,6 +538,7 @@ impl CandleBackend {
}
#[cfg(feature = "cuda")]
(Config::NomicBert(config), Device::Cuda(_)) => {
// TODO(alvarobartt): Include the `dtype` as an arg in `use_flash_attn`
if dtype == DType::F16 && use_flash_attn(&[FlashAttn::V2]) {
tracing::info!("Starting FlashNomicBert model on {:?}", device);
Ok(Box::new(
@@ -513,6 +551,7 @@ impl CandleBackend {
}
#[cfg(feature = "cuda")]
(Config::Qwen2(config), Device::Cuda(_)) => {
// TODO(alvarobartt): Include the `dtype` as an arg in `use_flash_attn`
if !(dtype == DType::F16 && use_flash_attn(&[FlashAttn::V1, FlashAttn::V2])) {
return Err(BackendError::Start("Qwen2 is only supported on Cuda devices in fp16 with flash attention v2 enabled".to_string()));
}
@@ -523,6 +562,7 @@ impl CandleBackend {
}
#[cfg(feature = "cuda")]
(Config::Qwen3(config), Device::Cuda(_)) => {
// TODO(alvarobartt): Include the `dtype` as an arg in `use_flash_attn`
if dtype == DType::F16 && use_flash_attn(&[FlashAttn::V1, FlashAttn::V2]) {
tracing::info!("Starting FlashQwen3 model on {:?}", device);
Ok(Box::new(
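
Several of the hunks above add the same TODO about passing `dtype` into the flash-attention check. Commit df60701 introduces a `supports_flash_attn` helper for exactly this; its actual signature is not visible in this diff, so the following is only a sketch of the idea, reusing the `use_flash_attn` and `FlashAttn` items already referenced above:

```rust
// Sketch only: fold the repeated `dtype == DType::F16 && use_flash_attn(...)`
// guard into a single helper so the dtype check lives in one place.
fn supports_flash_attn(dtype: DType, variants: &[FlashAttn]) -> bool {
    // Flash attention is currently only wired up for F16 weights; BF16 here
    // still depends on pending `candle` / `candle-extensions` updates.
    dtype == DType::F16 && use_flash_attn(variants)
}
```

Each call site would then reduce to `if supports_flash_attn(dtype, &[FlashAttn::V1, FlashAttn::V2]) { ... }`.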
6 changes: 5 additions & 1 deletion backends/candle/src/models/gemma3.rs
@@ -260,7 +260,11 @@ impl Gemma3Attention {
) -> Result<Tensor> {
let min_value = match dtype {
DType::F32 => f32::MIN,
_ => -65504.0, // f16 minimum value
DType::BF16 => -3.3895314e38_f32,
DType::F16 => -65504.0_f32,
// SAFETY: Default to the F16 minimum finite value, even though `dtype` should
// always match one of the variants above
_ => -65504.0_f32,
};

let mask: Vec<u8> = (0..seq_len)
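
The `-3.3895314e38` constant used above (and in the modernbert, mpnet and qwen3 hunks below) is BF16's minimum finite value: BF16 keeps f32's 8-bit exponent but only 7 mantissa bits, so its largest finite magnitude is (2 - 2^-7) * 2^127 ≈ 3.39e38, versus 65504 for FP16. A quick standalone check (illustrative only; the crate itself no longer depends on the `half` crate):

```rust
// Illustrative sanity check for the BF16 minimum-finite constant above.
fn main() {
    // BF16: sign bit, 8 exponent bits (same as f32), 7 mantissa bits.
    let bf16_min_finite = -(2.0 - 2f32.powi(-7)) * 2f32.powi(127);
    assert_eq!(bf16_min_finite, -3.3895314e38_f32);
    // FP16 for comparison: (2 - 2^-10) * 2^15 = 65504.
    let f16_min_finite = -(2.0 - 2f32.powi(-10)) * 2f32.powi(15);
    assert_eq!(f16_min_finite, -65504.0_f32);
}
```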
6 changes: 5 additions & 1 deletion backends/candle/src/models/modernbert.rs
@@ -681,7 +681,11 @@ impl ModernBertModel {

let min_value = match self.dtype {
DType::F32 => f32::MIN as f64,
_ => -65504.0, // f16 minimum value
DType::BF16 => -3.3895314e38_f64,
DType::F16 => -65504.0_f64,
// SAFETY: Default to the F16 minimum finite value, even though `dtype` should
// always match one of the variants above
_ => -65504.0_f64,
};

let global_attention_mask = ((1.0 - global_attention_mask)? * min_value)?;
6 changes: 5 additions & 1 deletion backends/candle/src/models/mpnet.rs
@@ -495,7 +495,11 @@ impl MPNetModel {

let min_value = match self.dtype {
DType::F32 => f32::MIN as f64,
_ => -65504.0_f64, // f16 minumum value
DType::BF16 => -3.3895314e38_f64,
DType::F16 => -65504.0_f64,
// SAFETY: Default to the F16 minimum finite value, even though `dtype` should
// always match one of the variants above
_ => -65504.0_f64,
};

let extended_attention_mask = ((1.0 - extended_attention_mask)? * min_value)?;
6 changes: 5 additions & 1 deletion backends/candle/src/models/qwen3.rs
@@ -501,7 +501,11 @@ impl Qwen3Model {

let min_value = match self.dtype {
DType::F32 => f32::MIN,
_ => -65504.0, // f16 minimum value
DType::BF16 => -3.3895314e38_f32,
DType::F16 => -65504.0_f32,
// SAFETY: Default to the F16 minimum finite value, even though `dtype` should
// always match one of the variants above
_ => -65504.0_f32,
};

let negatives =
4 changes: 2 additions & 2 deletions backends/candle/tests/common.rs
@@ -290,14 +290,14 @@ pub fn load_tokenizer(model_root: &Path) -> Result<Tokenizer> {
m.set_prepend_scheme(PrependScheme::First);
tokenizer.with_pre_tokenizer(Some(PreTokenizerWrapper::Metaspace(m)));
} else if let PreTokenizerWrapper::Sequence(s) = pre_tokenizer {
let pre_tokenizers = s.get_pre_tokenizers();
let pre_tokenizers: Vec<_> = s.clone().into_iter().collect();
// Check if we have a Metaspace pre tokenizer in the sequence
let has_metaspace = pre_tokenizers
.iter()
.any(|t| matches!(t, PreTokenizerWrapper::Metaspace(_)));

if has_metaspace {
let mut new_pre_tokenizers = Vec::with_capacity(s.get_pre_tokenizers().len());
let mut new_pre_tokenizers = Vec::with_capacity(pre_tokenizers.len());

for pre_tokenizer in pre_tokenizers {
if let PreTokenizerWrapper::WhitespaceSplit(_) = pre_tokenizer {
36 changes: 31 additions & 5 deletions backends/src/dtype.rs
@@ -1,35 +1,61 @@
use std::fmt;
use std::{fmt, str::FromStr};

#[cfg(feature = "clap")]
use clap::ValueEnum;

#[derive(Debug, PartialEq)]
#[cfg_attr(feature = "clap", derive(Clone, ValueEnum))]
pub enum DType {
// Float16 is not available on accelerate
#[cfg(any(
feature = "python",
all(feature = "candle", not(feature = "accelerate"))
))]
Float16,
#[cfg(any(feature = "python", feature = "candle", feature = "ort"))]
Float32,
#[cfg(feature = "python")]
// NOTE: For CUDA, BF16 requires Ampere (SM 80) or newer, which is validated at runtime, as
// there are no specific features for the different CUDA compute capabilities to filter out
// Turing and Volta from having `DType::Bfloat16`.
// NOTE: At the moment only Intel HPU and Metal are supported, given that there are still a few
// missing pieces to update `candle` and `candle-extensions` w/ support for BF16 Flash Attn
#[cfg(any(feature = "python", all(feature = "candle", feature = "metal")))]
Bfloat16,
}

#[derive(Debug, PartialEq, Eq)]
pub struct DTypeParseError;

impl FromStr for DType {
type Err = DTypeParseError;

fn from_str(s: &str) -> Result<Self, Self::Err> {
let dtype = match s {
"float32" => DType::Float32,
#[cfg(any(
feature = "python",
all(feature = "candle", not(feature = "accelerate"))
))]
"float16" => DType::Float16,
#[cfg(any(feature = "python", all(feature = "candle", feature = "metal")))]
"bfloat16" => DType::Bfloat16,
_ => return Err(DTypeParseError),
};

Ok(dtype)
}
}

impl fmt::Display for DType {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
// Float16 is not available on accelerate
#[cfg(any(
feature = "python",
all(feature = "candle", not(feature = "accelerate"))
))]
DType::Float16 => write!(f, "float16"),
#[cfg(any(feature = "python", feature = "candle", feature = "ort"))]
DType::Float32 => write!(f, "float32"),
#[cfg(feature = "python")]
#[cfg(any(feature = "python", all(feature = "candle", feature = "metal")))]
DType::Bfloat16 => write!(f, "bfloat16"),
}
}
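
The new `FromStr` impl lets the dtype strings accepted on the CLI be parsed directly. A usage sketch, assuming a build where the `Float16` and `Bfloat16` variants are compiled in (i.e. the feature gates above are satisfied):

```rust
// Illustrative only: parsing CLI dtype strings with the FromStr impl above.
fn parse_dtypes() {
    assert_eq!("float32".parse::<DType>(), Ok(DType::Float32));
    assert_eq!("float16".parse::<DType>(), Ok(DType::Float16));
    assert_eq!("bfloat16".parse::<DType>(), Ok(DType::Bfloat16));
    // Anything else maps to the dedicated error type.
    assert_eq!("int8".parse::<DType>(), Err(DTypeParseError));
    // Display round-trips the accepted spellings.
    assert_eq!(DType::Bfloat16.to_string(), "bfloat16");
}
```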
4 changes: 2 additions & 2 deletions router/Cargo.toml
@@ -91,6 +91,6 @@ candle = ["text-embeddings-backend/candle"]
candle-cuda = ["candle", "text-embeddings-backend/flash-attn", "dep:cudarc"]
candle-cuda-turing = ["candle", "text-embeddings-backend/flash-attn-v1", "dep:cudarc"]
candle-cuda-volta = ["candle", "text-embeddings-backend/cuda", "dep:cudarc"]
static-linking = ["cudarc?/static-linking", "intel-mkl-src?/mkl-static-lp64-iomp"]
dynamic-linking = ["cudarc?/dynamic-linking", "intel-mkl-src?/mkl-dynamic-lp64-iomp"]
static-linking = ["cudarc?/static-linking", "intel-mkl-src?/mkl-static-lp64-iomp", "text-embeddings-backend/static-linking"]
dynamic-linking = ["cudarc?/dynamic-linking", "intel-mkl-src?/mkl-dynamic-lp64-iomp", "text-embeddings-backend/dynamic-linking"]
google = []