use std::{
any::{Any, TypeId},
cell::UnsafeCell,
fmt::Debug,
future::Future,
mem::ManuallyDrop,
pin::Pin,
sync::Arc,
task::{Context, RawWaker, RawWakerVTable, Waker},
};
use crossbeam_channel::{Receiver, Sender};
/// Shared, thread-safe notification callback.
///
/// The closure lives behind an `Arc`, so every clone of a `CommandWaker` runs the same closure.
/// `CommandProxy::cmd` invokes it right after queueing a command.
#[derive(Clone)]
pub struct CommandWaker(Arc<dyn Fn() + Send + Sync>);

impl CommandWaker {
    /// Wraps `wake` so it can be cloned and called from any thread.
    pub fn new(wake: impl Fn() + Send + Sync + 'static) -> Self {
        let shared: Arc<dyn Fn() + Send + Sync> = Arc::new(wake);
        Self::from(shared)
    }

    /// Runs the wrapped callback once.
    pub fn wake(&self) {
        let callback: &(dyn Fn() + Send + Sync) = &*self.0;
        callback();
    }
}

impl From<Arc<dyn Fn() + Send + Sync>> for CommandWaker {
    /// Adopts an already-shared callback without allocating a new `Arc`.
    fn from(wake: Arc<dyn Fn() + Send + Sync>) -> Self {
        CommandWaker(wake)
    }
}

impl Debug for CommandWaker {
    /// Prints only the type name; the wrapped closure is opaque.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("CommandWaker")
    }
}
/// A type-erased, `Send` message carried through the command channel.
///
/// The payload's `TypeId` and type name are captured at construction so the command can be
/// inspected (`is`, `get`, `name`) without consuming it.
pub struct Command {
    type_id: TypeId,
    data: Box<dyn Any + Send>,
    name: &'static str,
}

impl Command {
    /// Boxes `command` and records its `TypeId` and `std::any::type_name`.
    pub fn new<T: Any + Send>(command: T) -> Self {
        Self {
            type_id: TypeId::of::<T>(),
            data: Box::new(command),
            name: std::any::type_name::<T>(),
        }
    }

    /// Returns `true` if the payload is exactly of type `T`.
    pub fn is<T: Any>(&self) -> bool {
        self.type_id == TypeId::of::<T>()
    }

    /// Borrows the payload as `T`, or returns `None` if it holds a different type.
    pub fn get<T: Any>(&self) -> Option<&T> {
        // Dereference the box first so `downcast_ref` checks the payload's own `TypeId`
        // through the trait-object vtable, not the `Box`'s. This replaces the previous
        // hand-written `unsafe` pointer cast with the standard library's checked downcast.
        (*self.data).downcast_ref::<T>()
    }

    /// The payload's type name as reported by `std::any::type_name` (intended for diagnostics;
    /// the exact text is not guaranteed stable across compiler versions).
    pub fn name(&self) -> &'static str {
        self.name
    }

    /// Consumes the command and returns the boxed payload, e.g. for `Box::downcast`.
    pub fn to_any(self) -> Box<dyn Any + Send> {
        self.data
    }
}

impl Debug for Command {
    /// Shows only the payload's type name; the payload itself need not implement `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Command").field("type", &self.name).finish()
    }
}
#[derive(Clone)]
pub struct CommandProxy {
tx: Sender<Command>,
waker: CommandWaker,
}
impl CommandProxy {
pub fn new(waker: CommandWaker) -> (Self, CommandReceiver) {
let (tx, rx) = crossbeam_channel::unbounded();
(Self { tx, waker }, CommandReceiver { rx })
}
pub fn wake(&self) {
self.waker.wake();
}
pub fn cmd_silent(&self, command: Command) {
if let Err(err) = self.tx.send(command) {
tracing::warn!("failed to send command: {}", err);
}
}
pub fn cmd(&self, command: impl Any + Send) {
self.cmd_silent(Command::new(command));
self.wake();
}
pub fn spawn_async(&self, future: impl Future<Output = ()> + Send + 'static) {
let task = Arc::new(CommandTask::new(self, future));
unsafe { task.poll() };
}
pub fn cmd_async<T: Any + Send>(&self, future: impl Future<Output = T> + Send + 'static) {
let proxy = self.clone();
self.spawn_async(async move {
proxy.cmd(future.await);
});
}
}
impl Debug for CommandProxy {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("CommandProxy").finish()
}
}
/// Receiving half of the command channel, paired with a `CommandProxy`.
///
/// Task wake-ups travel through the same channel as ordinary commands; `try_recv` polls those
/// internally and only hands ordinary commands to the caller.
pub struct CommandReceiver {
    rx: Receiver<Command>,
}

impl CommandReceiver {
    /// Non-blocking receive of the next raw message; `None` if the channel is empty or
    /// disconnected.
    fn try_recv_inner(&self) -> Option<Command> {
        match self.rx.try_recv() {
            Ok(command) => Some(command),
            Err(_) => None,
        }
    }

    /// Returns the next ordinary command without blocking.
    ///
    /// Any queued task wake-ups (`CommandTaskShared` payloads) met on the way are polled in
    /// place and skipped. Returns `None` once the channel has nothing more to offer.
    pub fn try_recv(&self) -> Option<Command> {
        loop {
            let command = self.try_recv_inner()?;
            match command.get::<CommandTaskShared>() {
                // SAFETY: relies on `CommandTask::poll`'s own guarantees for exclusive access to
                // the task's future. NOTE(review): the same task may be polled concurrently from
                // `CommandProxy::spawn_async` on another thread; confirm `poll` tolerates that.
                Some(task) => unsafe { task.poll() },
                None => return Some(command),
            }
        }
    }
}

impl Debug for CommandReceiver {
    /// Prints only the type name; the channel is opaque.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("CommandReceiver")
    }
}
type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
struct CommandTask {
proxy: CommandProxy,
future: UnsafeCell<Option<BoxFuture<'static, ()>>>,
}
type CommandTaskShared = Arc<CommandTask>;
unsafe impl Sync for CommandTask {}
impl CommandTask {
fn new(proxy: &CommandProxy, future: impl Future<Output = ()> + Send + 'static) -> Self {
Self {
proxy: proxy.clone(),
future: UnsafeCell::new(Some(Box::pin(future))),
}
}
fn raw_waker_vtable() -> &'static RawWakerVTable {
&RawWakerVTable::new(
CommandTask::waker_clone,
CommandTask::waker_wake,
CommandTask::waker_wake_by_ref,
CommandTask::waker_drop,
)
}
unsafe fn increase_refcount(data: *const ()) {
let arc = ManuallyDrop::new(Arc::from_raw(data.cast::<Self>()));
let _arc_clone = arc.clone();
}
unsafe fn waker_clone(data: *const ()) -> RawWaker {
Self::increase_refcount(data);
RawWaker::new(data, Self::raw_waker_vtable())
}
unsafe fn waker_wake(data: *const ()) {
let arc = Arc::from_raw(data.cast::<Self>());
arc.proxy.cmd(arc.clone());
}
unsafe fn waker_wake_by_ref(data: *const ()) {
let arc = ManuallyDrop::new(Arc::from_raw(data.cast::<Self>()));
let task: Arc<Self> = (*arc).clone();
arc.proxy.cmd(task);
}
unsafe fn waker_drop(data: *const ()) {
drop(Arc::from_raw(data.cast::<Self>()));
}
fn raw_waker(self: &CommandTaskShared) -> RawWaker {
let data = CommandTaskShared::into_raw(self.clone());
RawWaker::new(data.cast(), Self::raw_waker_vtable())
}
unsafe fn poll(self: &CommandTaskShared) {
let future_slot = &mut *self.future.get();
if let Some(mut future) = future_slot.take() {
let waker = Waker::from_raw(self.raw_waker());
let mut cx = Context::from_waker(&waker);
if future.as_mut().poll(&mut cx).is_pending() {
*future_slot = Some(future);
}
}
}
}