refactor: rename sqlite3-safetensor-cosine to sqlite3-ndarray-math
Some checks failed
build / checks-matrix (push) Successful in 19m20s
build / codecov (push) Failing after 19m22s
docs / docs (push) Failing after 28m47s
build / checks-build (push) Has been cancelled

This commit is contained in:
uttarayan21
2025-08-28 18:42:35 +05:30
parent ac8f1d01b4
commit aaf34ef74e
4 changed files with 5 additions and 5 deletions

View File

@@ -0,0 +1,14 @@
[package]
name = "sqlite3-ndarray-math"
version.workspace = true
edition.workspace = true
[lib]
crate-type = ["cdylib", "staticlib"]
[dependencies]
ndarray = "0.16.1"
ndarray-math = { git = "https://git.darksailor.dev/servius/ndarray-math", version = "0.1.0" }
# ndarray-math = { path = "/Users/fs0c131y/Projects/ndarray-math", version = "0.1.0" }
ndarray-safetensors = { version = "0.1.0", path = "../ndarray-safetensors" }
sqlite-loadable = "0.0.5"

View File

@@ -0,0 +1,61 @@
use sqlite_loadable::prelude::*;
use sqlite_loadable::{Error, ErrorKind};
use sqlite_loadable::{Result, api, define_scalar_function};
fn cosine_similarity(context: *mut sqlite3_context, values: &[*mut sqlite3_value]) -> Result<()> {
#[inline(always)]
fn custom_error(err: impl core::error::Error) -> sqlite_loadable::Error {
sqlite_loadable::Error::new(sqlite_loadable::ErrorKind::Message(err.to_string()))
}
if values.len() != 2 {
return Err(Error::new(ErrorKind::Message(
"cosine_similarity requires exactly 2 arguments".to_string(),
)));
}
let array_1 = api::value_blob(values.get(0).expect("1st argument"));
let array_2 = api::value_blob(values.get(1).expect("2nd argument"));
let array_1_st =
ndarray_safetensors::SafeArraysView::from_bytes(array_1).map_err(custom_error)?;
let array_2_st =
ndarray_safetensors::SafeArraysView::from_bytes(array_2).map_err(custom_error)?;
let array_view_1 = array_1_st
.tensor_by_index::<f32, ndarray::Ix1>(0)
.map_err(custom_error)?;
let array_view_2 = array_2_st
.tensor_by_index::<f32, ndarray::Ix1>(0)
.map_err(custom_error)?;
use ndarray_math::*;
let similarity = array_view_1
.cosine_similarity(array_view_2)
.map_err(custom_error)?;
api::result_double(context, similarity as f64);
Ok(())
}
pub fn _sqlite3_extension_init(db: *mut sqlite3) -> Result<()> {
define_scalar_function(
db,
"cosine_similarity",
2,
cosine_similarity,
FunctionFlags::DETERMINISTIC,
)?;
Ok(())
}
/// # Safety
///
/// Should only be called by underlying SQLite C APIs,
/// like sqlite3_auto_extension and sqlite3_cancel_auto_extension.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn sqlite3_extension_init(
db: *mut sqlite3,
pz_err_msg: *mut *mut c_char,
p_api: *mut sqlite3_api_routines,
) -> c_uint {
register_entrypoint(db, pz_err_msg, p_api, _sqlite3_extension_init)
}