refactor: rename sqlite3-safetensor-cosine to sqlite3-ndarray-math
This commit is contained in:
14
sqlite3-ndarray-math/Cargo.toml
Normal file
14
sqlite3-ndarray-math/Cargo.toml
Normal file
@@ -0,0 +1,14 @@
|
||||
[package]
|
||||
name = "sqlite3-ndarray-math"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[lib]
|
||||
crate-type = ["cdylib", "staticlib"]
|
||||
|
||||
[dependencies]
|
||||
ndarray = "0.16.1"
|
||||
ndarray-math = { git = "https://git.darksailor.dev/servius/ndarray-math", version = "0.1.0" }
|
||||
# ndarray-math = { path = "/Users/fs0c131y/Projects/ndarray-math", version = "0.1.0" }
|
||||
ndarray-safetensors = { version = "0.1.0", path = "../ndarray-safetensors" }
|
||||
sqlite-loadable = "0.0.5"
|
||||
61
sqlite3-ndarray-math/src/lib.rs
Normal file
61
sqlite3-ndarray-math/src/lib.rs
Normal file
@@ -0,0 +1,61 @@
|
||||
use sqlite_loadable::prelude::*;
|
||||
use sqlite_loadable::{Error, ErrorKind};
|
||||
use sqlite_loadable::{Result, api, define_scalar_function};
|
||||
|
||||
fn cosine_similarity(context: *mut sqlite3_context, values: &[*mut sqlite3_value]) -> Result<()> {
|
||||
#[inline(always)]
|
||||
fn custom_error(err: impl core::error::Error) -> sqlite_loadable::Error {
|
||||
sqlite_loadable::Error::new(sqlite_loadable::ErrorKind::Message(err.to_string()))
|
||||
}
|
||||
|
||||
if values.len() != 2 {
|
||||
return Err(Error::new(ErrorKind::Message(
|
||||
"cosine_similarity requires exactly 2 arguments".to_string(),
|
||||
)));
|
||||
}
|
||||
let array_1 = api::value_blob(values.get(0).expect("1st argument"));
|
||||
let array_2 = api::value_blob(values.get(1).expect("2nd argument"));
|
||||
let array_1_st =
|
||||
ndarray_safetensors::SafeArraysView::from_bytes(array_1).map_err(custom_error)?;
|
||||
let array_2_st =
|
||||
ndarray_safetensors::SafeArraysView::from_bytes(array_2).map_err(custom_error)?;
|
||||
|
||||
let array_view_1 = array_1_st
|
||||
.tensor_by_index::<f32, ndarray::Ix1>(0)
|
||||
.map_err(custom_error)?;
|
||||
let array_view_2 = array_2_st
|
||||
.tensor_by_index::<f32, ndarray::Ix1>(0)
|
||||
.map_err(custom_error)?;
|
||||
|
||||
use ndarray_math::*;
|
||||
let similarity = array_view_1
|
||||
.cosine_similarity(array_view_2)
|
||||
.map_err(custom_error)?;
|
||||
api::result_double(context, similarity as f64);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn _sqlite3_extension_init(db: *mut sqlite3) -> Result<()> {
|
||||
define_scalar_function(
|
||||
db,
|
||||
"cosine_similarity",
|
||||
2,
|
||||
cosine_similarity,
|
||||
FunctionFlags::DETERMINISTIC,
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// Should only be called by underlying SQLite C APIs,
|
||||
/// like sqlite3_auto_extension and sqlite3_cancel_auto_extension.
|
||||
#[unsafe(no_mangle)]
|
||||
pub unsafe extern "C" fn sqlite3_extension_init(
|
||||
db: *mut sqlite3,
|
||||
pz_err_msg: *mut *mut c_char,
|
||||
p_api: *mut sqlite3_api_routines,
|
||||
) -> c_uint {
|
||||
register_entrypoint(db, pz_err_msg, p_api, _sqlite3_extension_init)
|
||||
}
|
||||
Reference in New Issue
Block a user