use std::ffi::{c_int, c_uchar, c_uint, c_void};
use std::{mem, slice};
use log::error;
use mariadb_sys as bindings;
use super::encryption::{Action, Decryption, Encryption, Flags, KeyError, KeyManager};
/// C-ABI entry points for key management, layered on top of a [`KeyManager`] implementation.
///
/// Both functions are provided methods: they forward to the associated functions of the
/// implementing type and translate the returned `Result`s into the integer codes handed back
/// to the caller.
pub trait WrapKeyMgr: KeyManager {
    /// Return the latest key version for `key_id`.
    ///
    /// An `Err` from `Self::get_latest_key_version` is reported as `KeyError::InvalidVersion`.
    /// A successful value that collides with one of the reserved constants
    /// (`ENCRYPTION_KEY_NOT_ENCRYPTED`, `ENCRYPTION_KEY_VERSION_INVALID`) is logged and passed
    /// through unchanged.
    extern "C" fn wrap_get_latest_key_version(key_id: c_uint) -> c_uint {
        match Self::get_latest_key_version(key_id) {
            Ok(v) => {
                if v == bindings::ENCRYPTION_KEY_NOT_ENCRYPTED {
                    error!(
                        "get_latest_key_version returned value {v}, which is reserved for unencrypted \
                         keys. The server will likely restart."
                    );
                } else if v == bindings::ENCRYPTION_KEY_VERSION_INVALID {
                    // Logged instead of panicking: this function is invoked through a C function
                    // pointer and must not unwind into its caller. `v` already equals the
                    // invalid-version marker, so returning it reports the failure.
                    error!(
                        "get_latest_key_version returned value {v}, which is reserved for invalid keys. \
                         Return an Err if this is intended."
                    );
                }
                v
            }
            Err(_) => KeyError::InvalidVersion as c_uint,
        }
    }

    /// Copy the key identified by `key_id` / `version` into `dstbuf`.
    ///
    /// On entry `*buflen` is read as the capacity of `dstbuf`; on exit it holds:
    /// * the key length when `dstbuf` is null (size query) or too small, together with a
    ///   return value of `ENCRYPTION_KEY_BUFFER_TOO_SMALL`;
    /// * the key length and a return value of `0` when the key was written;
    /// * `0` and the error code when `Self::get_key` fails.
    ///
    /// If `Self::key_length` fails, its error code is returned and `*buflen` is left untouched.
    ///
    /// # Safety
    ///
    /// `buflen` must be valid for reads and writes. If `dstbuf` is non-null it must be valid for
    /// writes of `*buflen` bytes.
    unsafe extern "C" fn wrap_get_key(
        key_id: c_uint,
        version: c_uint,
        dstbuf: *mut c_uchar,
        buflen: *mut c_uint,
    ) -> c_uint {
        let key_len = match Self::key_length(key_id, version) {
            Ok(v) => v,
            Err(e) => return e as c_uint,
        };

        // A null destination is a pure size query.
        if dstbuf.is_null() {
            *buflen = key_len.try_into().unwrap();
            return bindings::ENCRYPTION_KEY_BUFFER_TOO_SMALL;
        }

        let buf = slice::from_raw_parts_mut(dstbuf, (*buflen).try_into().unwrap());
        let Some(sized_buf) = buf.get_mut(..key_len) else {
            error!(
                "key requires {key_len} bytes but received only {}",
                buf.len()
            );
            // Tell the caller how large the buffer has to be, exactly as the null-pointer
            // path does; otherwise it cannot retry with a correctly sized buffer. A length
            // that does not fit in `c_uint` cannot be reported and leaves `*buflen` as is.
            if let Ok(required) = c_uint::try_from(key_len) {
                *buflen = required;
            }
            return bindings::ENCRYPTION_KEY_BUFFER_TOO_SMALL;
        };

        // Hand the implementation a slice of exactly `key_len` bytes.
        let (ret, new_buflen) = match Self::get_key(key_id, version, sized_buf) {
            Ok(()) => (0, key_len.try_into().unwrap()),
            Err(e) => (e as c_uint, 0),
        };
        *buflen = new_buflen;
        ret
    }
}
/// Blanket implementation: every [`KeyManager`] gets the provided `WrapKeyMgr` methods.
impl<T: KeyManager> WrapKeyMgr for T {}
/// Crypt context that is written into, and later read back from, the opaque `ctx` buffer
/// passed to the `wrap_crypt_ctx_*` functions.
///
/// It holds either an encryption state or a decryption state; `wrap_crypt_ctx_init` picks the
/// variant from the action reported by the flags.
#[repr(C)]
enum CryptCtxWrapper<En, De> {
/// State of an encryption operation.
Encrypt(En),
/// State of a decryption operation.
Decrypt(De),
}
/// Report how many bytes a `CryptCtxWrapper<En, De>` occupies.
///
/// The key id and key version are accepted but not used: the size depends only on the types.
///
/// # Panics
///
/// Panics if the size does not fit in a `c_uint`.
pub extern "C" fn wrap_crypt_ctx_size<En: Encryption, De: Decryption>(
    _key_id: c_uint,
    _key_version: c_uint,
) -> c_uint {
    let ctx_bytes = mem::size_of::<CryptCtxWrapper<En, De>>();
    c_uint::try_from(ctx_bytes).unwrap()
}
/// Build an encryption or decryption context and store it in `ctx`.
///
/// The flags decide which side is initialized: `Action::Encrypt` calls `En::init`,
/// `Action::Decrypt` calls `De::init`. The `nopad` flag is forwarded as the last argument of
/// `init`. If `init` fails, its error is returned as a `c_int` and `ctx` is not written;
/// otherwise the new context is written to `ctx` and `MY_AES_OK` is returned.
///
/// # Safety
///
/// * `key` must be valid for reads of `klen` bytes and `iv` for reads of `ivlen` bytes.
/// * `ctx` must be valid for a write of a `CryptCtxWrapper<En, De>`.
pub unsafe extern "C" fn wrap_crypt_ctx_init<En: Encryption, De: Decryption>(
    ctx: *mut c_void,
    key: *const c_uchar,
    klen: c_uint,
    iv: *const c_uchar,
    ivlen: c_uint,
    flags: c_int,
    key_id: c_uint,
    key_version: c_uint,
) -> c_int {
    let key_bytes = slice::from_raw_parts(key, klen.try_into().unwrap());
    let iv_bytes = slice::from_raw_parts(iv, ivlen.try_into().unwrap());

    let parsed_flags = Flags::new(flags);
    let nopad = parsed_flags.nopad();

    // Bail out with the error code as soon as the chosen side fails to initialize.
    let fresh_ctx: CryptCtxWrapper<En, De> = match parsed_flags.action() {
        Action::Encrypt => match En::init(key_id, key_version, key_bytes, iv_bytes, nopad) {
            Ok(enc) => CryptCtxWrapper::Encrypt(enc),
            Err(e) => return e as c_int,
        },
        Action::Decrypt => match De::init(key_id, key_version, key_bytes, iv_bytes, nopad) {
            Ok(dec) => CryptCtxWrapper::Decrypt(dec),
            Err(e) => return e as c_int,
        },
    };

    // `write` stores the value without reading or dropping whatever bytes were there before.
    let slot = ctx.cast::<CryptCtxWrapper<En, De>>();
    slot.write(fresh_ctx);
    bindings::MY_AES_OK.try_into().unwrap()
}
/// Feed `slen` bytes from `src` through the context stored in `ctx`, writing output to `dst`.
///
/// The destination slice is created with the same length as the source (`slen`); `*dlen` is
/// only written, never read. On success `*dlen` receives the count returned by `update` and
/// `MY_AES_OK` is returned; on failure `*dlen` is set to `0` and the error is returned as a
/// `c_int`.
///
/// # Safety
///
/// * `ctx` must point to a `CryptCtxWrapper<En, De>` that has been initialized.
/// * `src` must be valid for reads and `dst` for writes of `slen` bytes.
/// * `dlen` must be valid for writes.
pub unsafe extern "C" fn wrap_crypt_ctx_update<En: Encryption, De: Decryption>(
    ctx: *mut c_void,
    src: *const c_uchar,
    slen: c_uint,
    dst: *mut c_uchar,
    dlen: *mut c_uint,
) -> c_int {
    let input = slice::from_raw_parts(src, slen.try_into().unwrap());
    let output = slice::from_raw_parts_mut(dst, slen.try_into().unwrap());

    let wrapper = &mut *ctx.cast::<CryptCtxWrapper<En, De>>();
    let outcome = match wrapper {
        CryptCtxWrapper::Encrypt(enc) => enc.update(input, output),
        CryptCtxWrapper::Decrypt(dec) => dec.update(input, output),
    };

    match outcome {
        Ok(n) => {
            let status: c_int = bindings::MY_AES_OK.try_into().unwrap();
            let produced: c_uint = n.try_into().unwrap();
            *dlen = produced;
            status
        }
        Err(e) => {
            *dlen = 0;
            e as c_int
        }
    }
}
/// Finalize the operation held in `ctx`, write any trailing output to `dst`, and destroy the
/// context.
///
/// On success `*dlen` receives the count returned by `finish` and `MY_AES_OK` is returned; on
/// failure `*dlen` is set to `0` and the error is returned as a `c_int`. In both cases the
/// context is dropped in place before returning, so `ctx` must not be used again without a new
/// call to `wrap_crypt_ctx_init`.
///
/// # Safety
///
/// * `ctx` must point to a `CryptCtxWrapper<En, De>` that has been initialized and not yet
///   finished.
/// * `dlen` must be valid for reads and writes, and `dst` must be valid for writes of `*dlen`
///   bytes.
pub unsafe extern "C" fn wrap_crypt_ctx_finish<En: Encryption, De: Decryption>(
    ctx: *mut c_void,
    dst: *mut c_uchar,
    dlen: *mut c_uint,
) -> c_int {
    // Recover the concrete type once; both the method call and the drop below need it.
    let typed_ctx = ctx.cast::<CryptCtxWrapper<En, De>>();

    // NOTE(review): `*dlen` is read here as the capacity of `dst` — confirm that callers
    // initialize it before invoking this function.
    let dbuf = slice::from_raw_parts_mut(dst, (*dlen).try_into().unwrap());

    let finish_res = match &mut *typed_ctx {
        CryptCtxWrapper::Encrypt(v) => v.finish(dbuf),
        CryptCtxWrapper::Decrypt(v) => v.finish(dbuf),
    };
    let (ret, written) = match finish_res {
        Ok(v) => (
            bindings::MY_AES_OK.try_into().unwrap(),
            v.try_into().unwrap(),
        ),
        Err(e) => (e as c_int, 0),
    };
    *dlen = written;

    // Drop through the typed pointer. Calling `drop_in_place` on the `*mut c_void` would drop
    // a `c_void`, which does nothing, and the `En` / `De` state would never be released.
    typed_ctx.drop_in_place();
    ret
}
/// Ask `En::encrypted_length` for the encrypted size of `slen` input bytes and return it as a
/// `c_uint`.
///
/// # Panics
///
/// Panics if either integer conversion (input length or result) fails.
///
/// # Safety
///
/// The body performs no unsafe operation itself; the function is `unsafe` only by signature.
pub unsafe extern "C" fn wrap_encrypted_length<En: Encryption>(
    slen: c_uint,
    key_id: c_uint,
    key_version: c_uint,
) -> c_uint {
    let cipher_len = En::encrypted_length(key_id, key_version, slen.try_into().unwrap());
    cipher_len.try_into().unwrap()
}
/// Store `val` in `*buflen`, logging an error first when `val` is greater than 32.
///
/// The value is written regardless of whether the error was logged.
///
/// # Safety
///
/// `buflen` must be valid for writes.
#[allow(dead_code)]
unsafe fn set_buflen_with_check(buflen: *mut c_uint, val: u32) {
    const LOG_THRESHOLD: u32 = 32;

    let over_threshold = val > LOG_THRESHOLD;
    if over_threshold {
        error!(
            "The default encryption does not seem to allow keys above 32 bits. If the server \
             crashes after this message, that is the likely error"
        );
    }

    let converted = c_uint::try_from(val).unwrap();
    *buflen = converted;
}