//! SQLite loadable extension that registers a `cosine_similarity` scalar function.
use sqlite_loadable::prelude::*;
|
|
use sqlite_loadable::{Error, ErrorKind};
|
|
use sqlite_loadable::{Result, api, define_scalar_function};
|
|
|
|
fn cosine_similarity(context: *mut sqlite3_context, values: &[*mut sqlite3_value]) -> Result<()> {
|
|
#[inline(always)]
|
|
fn custom_error(err: impl core::error::Error) -> sqlite_loadable::Error {
|
|
sqlite_loadable::Error::new(sqlite_loadable::ErrorKind::Message(err.to_string()))
|
|
}
|
|
|
|
if values.len() != 2 {
|
|
return Err(Error::new(ErrorKind::Message(
|
|
"cosine_similarity requires exactly 2 arguments".to_string(),
|
|
)));
|
|
}
|
|
let array_1 = api::value_blob(values.get(0).expect("1st argument"));
|
|
let array_2 = api::value_blob(values.get(1).expect("2nd argument"));
|
|
let array_1_st =
|
|
ndarray_safetensors::SafeArraysView::from_bytes(array_1).map_err(custom_error)?;
|
|
let array_2_st =
|
|
ndarray_safetensors::SafeArraysView::from_bytes(array_2).map_err(custom_error)?;
|
|
|
|
let array_view_1 = array_1_st
|
|
.tensor_by_index::<f32, ndarray::Ix1>(0)
|
|
.map_err(custom_error)?;
|
|
let array_view_2 = array_2_st
|
|
.tensor_by_index::<f32, ndarray::Ix1>(0)
|
|
.map_err(custom_error)?;
|
|
|
|
use ndarray_math::*;
|
|
let similarity = array_view_1
|
|
.cosine_similarity(array_view_2)
|
|
.map_err(custom_error)?;
|
|
api::result_double(context, similarity as f64);
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub fn _sqlite3_extension_init(db: *mut sqlite3) -> Result<()> {
|
|
define_scalar_function(
|
|
db,
|
|
"cosine_similarity",
|
|
2,
|
|
cosine_similarity,
|
|
FunctionFlags::DETERMINISTIC,
|
|
)?;
|
|
Ok(())
|
|
}
|
|
|
|
/// # Safety
|
|
///
|
|
/// Should only be called by underlying SQLite C APIs,
|
|
/// like sqlite3_auto_extension and sqlite3_cancel_auto_extension.
|
|
#[unsafe(no_mangle)]
|
|
pub unsafe extern "C" fn sqlite3_extension_init(
|
|
db: *mut sqlite3,
|
|
pz_err_msg: *mut *mut c_char,
|
|
p_api: *mut sqlite3_api_routines,
|
|
) -> c_uint {
|
|
register_entrypoint(db, pz_err_msg, p_api, _sqlite3_extension_init)
|
|
}
|