use sqlite_loadable::prelude::*; use sqlite_loadable::{Error, ErrorKind}; use sqlite_loadable::{Result, api, define_scalar_function}; fn cosine_similarity(context: *mut sqlite3_context, values: &[*mut sqlite3_value]) -> Result<()> { #[inline(always)] fn custom_error(err: impl core::error::Error) -> sqlite_loadable::Error { sqlite_loadable::Error::new(sqlite_loadable::ErrorKind::Message(err.to_string())) } if values.len() != 2 { return Err(Error::new(ErrorKind::Message( "cosine_similarity requires exactly 2 arguments".to_string(), ))); } let array_1 = api::value_blob(values.get(0).expect("1st argument")); let array_2 = api::value_blob(values.get(1).expect("2nd argument")); let array_1_st = ndarray_safetensors::SafeArraysView::from_bytes(array_1).map_err(custom_error)?; let array_2_st = ndarray_safetensors::SafeArraysView::from_bytes(array_2).map_err(custom_error)?; let array_view_1 = array_1_st .tensor_by_index::(0) .map_err(custom_error)?; let array_view_2 = array_2_st .tensor_by_index::(0) .map_err(custom_error)?; use ndarray_math::*; let similarity = array_view_1 .cosine_similarity(array_view_2) .map_err(custom_error)?; api::result_double(context, similarity as f64); Ok(()) } pub fn _sqlite3_extension_init(db: *mut sqlite3) -> Result<()> { define_scalar_function( db, "cosine_similarity", 2, cosine_similarity, FunctionFlags::DETERMINISTIC, )?; Ok(()) } /// # Safety /// /// Should only be called by underlying SQLite C APIs, /// like sqlite3_auto_extension and sqlite3_cancel_auto_extension. #[unsafe(no_mangle)] pub unsafe extern "C" fn sqlite3_extension_init( db: *mut sqlite3, pz_err_msg: *mut *mut c_char, p_api: *mut sqlite3_api_routines, ) -> c_uint { register_entrypoint(db, pz_err_msg, p_api, _sqlite3_extension_init) }