refactor to use a set of traits for transport
(i want to try and make this use webtransport)
This commit is contained in:
parent
8438ca11b5
commit
9248fe91a9
5 changed files with 329 additions and 230 deletions
5
server/Cargo.lock
generated
5
server/Cargo.lock
generated
|
@ -34,9 +34,9 @@ checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "async-trait"
|
name = "async-trait"
|
||||||
version = "0.1.82"
|
version = "0.1.83"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "a27b8a3a6e1a44fa4c8baf1f653e4172e81486d4941f2237e20dc2d0cf4ddff1"
|
checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
|
@ -1406,6 +1406,7 @@ name = "vncstream_server"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
|
"async-trait",
|
||||||
"axum",
|
"axum",
|
||||||
"cudarc",
|
"cudarc",
|
||||||
"ffmpeg-next",
|
"ffmpeg-next",
|
||||||
|
|
|
@ -31,6 +31,7 @@ cudarc = { version = "0.12.1", features = [ "cuda-11050" ] }
|
||||||
tracing = "0.1.40"
|
tracing = "0.1.40"
|
||||||
tracing-subscriber = "0.3.18"
|
tracing-subscriber = "0.3.18"
|
||||||
xkeysym = "0.2.1"
|
xkeysym = "0.2.1"
|
||||||
|
async-trait = "0.1.83"
|
||||||
|
|
||||||
|
|
||||||
[patch.crates-io]
|
[patch.crates-io]
|
||||||
|
|
|
@ -3,7 +3,12 @@ mod surface;
|
||||||
mod types;
|
mod types;
|
||||||
mod video;
|
mod video;
|
||||||
|
|
||||||
|
mod transport;
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
|
||||||
use retro_thread::{spawn_retro_thread, RetroEvent};
|
use retro_thread::{spawn_retro_thread, RetroEvent};
|
||||||
|
use transport::{Transport, TransportReciever};
|
||||||
use video::encoder_thread::EncodeThreadInput;
|
use video::encoder_thread::EncodeThreadInput;
|
||||||
use video::{encoder_thread, ffmpeg};
|
use video::{encoder_thread, ffmpeg};
|
||||||
|
|
||||||
|
@ -43,22 +48,135 @@ struct AppState {
|
||||||
encoder_tx: Arc<TokioMutex<mpsc::Sender<EncodeThreadInput>>>,
|
encoder_tx: Arc<TokioMutex<mpsc::Sender<EncodeThreadInput>>>,
|
||||||
inputs: Arc<TokioMutex<Vec<u32>>>,
|
inputs: Arc<TokioMutex<Vec<u32>>>,
|
||||||
|
|
||||||
websocket_broadcast_tx: broadcast::Sender<WsMessage>,
|
transport: Arc<crate::transport::websocket::WebsocketTransport>,
|
||||||
websocket_count: TokioMutex<usize>,
|
connection_count: TokioMutex<usize>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AppState {
|
impl AppState {
|
||||||
fn new(encoder_tx: mpsc::Sender<EncodeThreadInput>) -> Self {
|
fn new(
|
||||||
let (broadcast_tx, _) = broadcast::channel(10);
|
encoder_tx: mpsc::Sender<EncodeThreadInput>,
|
||||||
|
transport: Arc<crate::transport::websocket::WebsocketTransport>,
|
||||||
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
encoder_tx: Arc::new(TokioMutex::new(encoder_tx)),
|
encoder_tx: Arc::new(TokioMutex::new(encoder_tx)),
|
||||||
inputs: Arc::new(TokioMutex::new(Vec::new())),
|
inputs: Arc::new(TokioMutex::new(Vec::new())),
|
||||||
websocket_broadcast_tx: broadcast_tx,
|
transport: transport,
|
||||||
websocket_count: TokioMutex::const_new(0usize),
|
connection_count: TokioMutex::const_new(0usize),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl TransportReciever for AppState {
|
||||||
|
async fn on_connect(&self, username: &String) -> anyhow::Result<()> {
|
||||||
|
println!("{username} joined!");
|
||||||
|
|
||||||
|
{
|
||||||
|
let mut lk = self.connection_count.lock().await;
|
||||||
|
*lk += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
let locked = self.encoder_tx.lock().await;
|
||||||
|
|
||||||
|
// Force a ws connection to mean a keyframe
|
||||||
|
let _ = locked.send(EncodeThreadInput::ForceKeyframe).await;
|
||||||
|
let _ = locked.send(EncodeThreadInput::SendFrame).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn on_message(&self, username: &String, message: &String) -> anyhow::Result<()> {
|
||||||
|
if let Ok(json) = serde_json::from_str::<serde_json::Value>(&message) {
|
||||||
|
if !json["type"].is_string() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
match json["type"].as_str().unwrap() {
|
||||||
|
"chat" => {
|
||||||
|
if !json["msg"].is_string() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let send = serde_json::json!({
|
||||||
|
"type": "chat",
|
||||||
|
"username": username,
|
||||||
|
"msg": json["msg"].as_str().unwrap()
|
||||||
|
});
|
||||||
|
|
||||||
|
|
||||||
|
self.transport
|
||||||
|
.broadcast_message(transport::TransportMessage::Text(
|
||||||
|
serde_json::to_string(&send).expect("oh well"),
|
||||||
|
))?;
|
||||||
|
}
|
||||||
|
|
||||||
|
"key" => {
|
||||||
|
if !json["keysym"].is_number() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
if !json["pressed"].is_number() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let keysym = json["keysym"].as_u64().unwrap() as u32;
|
||||||
|
let pressed = json["pressed"].as_u64().unwrap() == 1;
|
||||||
|
|
||||||
|
// FIXME: This would be MUCH better off being a set, so we don't
|
||||||
|
// hack-code set semantics here. Oh well.
|
||||||
|
{
|
||||||
|
let mut lock = self.inputs.lock().await;
|
||||||
|
if pressed {
|
||||||
|
if let None = lock.iter().position(|e| *e == keysym) {
|
||||||
|
lock.push(keysym);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if let Some(at) = lock.iter().position(|e| *e == keysym) {
|
||||||
|
lock.remove(at);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
"mouse" => {
|
||||||
|
if json["x"].as_u64().is_none() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
if json["y"].as_u64().is_none() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
if json["mask"].as_u64().is_none() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
//let x = json["x"].as_u64().unwrap() as u32;
|
||||||
|
//let y = json["y"].as_u64().unwrap() as u32;
|
||||||
|
//let mask = json["mask"].as_u64().unwrap() as u8;
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn on_leave(&self, username: &String) -> anyhow::Result<()> {
|
||||||
|
{
|
||||||
|
let mut lk = self.connection_count.lock().await;
|
||||||
|
*lk -= 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
println!("{username} left.");
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::main(flavor = "multi_thread", worker_threads = 2)]
|
#[tokio::main(flavor = "multi_thread", worker_threads = 2)]
|
||||||
async fn main() -> anyhow::Result<()> {
|
async fn main() -> anyhow::Result<()> {
|
||||||
// Setup a tracing subscriber
|
// Setup a tracing subscriber
|
||||||
|
@ -74,12 +192,16 @@ async fn main() -> anyhow::Result<()> {
|
||||||
let frame: Arc<Mutex<Option<ffmpeg::frame::Video>>> = Arc::new(Mutex::new(None));
|
let frame: Arc<Mutex<Option<ffmpeg::frame::Video>>> = Arc::new(Mutex::new(None));
|
||||||
let (mut encoder_rx, encoder_tx) = encoder_thread::encoder_thread_spawn(&frame);
|
let (mut encoder_rx, encoder_tx) = encoder_thread::encoder_thread_spawn(&frame);
|
||||||
|
|
||||||
let state = Arc::new(AppState::new(encoder_tx));
|
let transport = Arc::new(crate::transport::websocket::WebsocketTransport::new());
|
||||||
|
|
||||||
|
let state = Arc::new(AppState::new(encoder_tx, transport.clone()));
|
||||||
|
|
||||||
let (mut event_rx, event_in_tx) = spawn_retro_thread(surface.clone());
|
let (mut event_rx, event_in_tx) = spawn_retro_thread(surface.clone());
|
||||||
|
|
||||||
let state_clone = state.clone();
|
let state_clone = state.clone();
|
||||||
|
|
||||||
|
let transport_clone = transport.clone();
|
||||||
|
|
||||||
// retro event handler. drives the encoder thread too
|
// retro event handler. drives the encoder thread too
|
||||||
let _ = std::thread::Builder::new()
|
let _ = std::thread::Builder::new()
|
||||||
.name("retro_event_rx".into())
|
.name("retro_event_rx".into())
|
||||||
|
@ -168,9 +290,18 @@ async fn main() -> anyhow::Result<()> {
|
||||||
match encoder_rx.try_recv() {
|
match encoder_rx.try_recv() {
|
||||||
Ok(msg) => match msg {
|
Ok(msg) => match msg {
|
||||||
encoder_thread::EncodeThreadOutput::Frame { packet } => {
|
encoder_thread::EncodeThreadOutput::Frame { packet } => {
|
||||||
let _ = state_clone
|
// let _ = state_clone
|
||||||
.websocket_broadcast_tx
|
// .websocket_broadcast_tx
|
||||||
.send(WsMessage::VideoPacket { packet });
|
// .send(WsMessage::VideoPacket { packet });
|
||||||
|
|
||||||
|
// :(
|
||||||
|
let packet_data = {
|
||||||
|
let slice = packet.data().expect(
|
||||||
|
"should NOT be empty, this invariant is checked beforehand",
|
||||||
|
);
|
||||||
|
slice.to_vec()
|
||||||
|
};
|
||||||
|
let _ = transport_clone.broadcast_message(transport::TransportMessage::Binary(packet_data));
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
Err(TryRecvError::Empty) => {}
|
Err(TryRecvError::Empty) => {}
|
||||||
|
@ -182,223 +313,6 @@ async fn main() -> anyhow::Result<()> {
|
||||||
})
|
})
|
||||||
.expect("failed to spawn retro RX thread, it's probably over");
|
.expect("failed to spawn retro RX thread, it's probably over");
|
||||||
|
|
||||||
// Axum websocket server
|
transport.listen(state).await?;
|
||||||
let app: Router<()> = Router::new()
|
|
||||||
.route(
|
|
||||||
"/",
|
|
||||||
get(
|
|
||||||
|ws: WebSocketUpgrade,
|
|
||||||
info: ConnectInfo<SocketAddr>,
|
|
||||||
state: State<Arc<AppState>>| async move {
|
|
||||||
ws_handler(ws, info, state).await
|
|
||||||
},
|
|
||||||
),
|
|
||||||
)
|
|
||||||
.with_state(state.clone());
|
|
||||||
|
|
||||||
let tcp_listener = tokio::net::TcpListener::bind("0.0.0.0:4940")
|
|
||||||
.await
|
|
||||||
.expect("failed to listen");
|
|
||||||
|
|
||||||
let axum_future = axum::serve(
|
|
||||||
tcp_listener,
|
|
||||||
app.into_make_service_with_connect_info::<SocketAddr>(),
|
|
||||||
);
|
|
||||||
|
|
||||||
// If the VNC client disconnects we should exit.
|
|
||||||
tokio::select! {
|
|
||||||
_ = axum_future => {
|
|
||||||
println!("axum died");
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn ws_handler(
|
|
||||||
ws: WebSocketUpgrade,
|
|
||||||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
|
||||||
|
|
||||||
State(state): State<Arc<AppState>>,
|
|
||||||
) -> impl IntoResponse {
|
|
||||||
// finalize the upgrade process by returning upgrade callback.
|
|
||||||
// we can customize the callback by sending additional info such as address.
|
|
||||||
ws.on_upgrade(move |socket| handle_socket(socket, addr, state))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Actual websocket statemachine (one will be spawned per connection)
|
|
||||||
async fn handle_socket(socket: WebSocket, who: SocketAddr, state: Arc<AppState>) {
|
|
||||||
let (mut sender, mut receiver) = socket.split();
|
|
||||||
|
|
||||||
// increment connection count
|
|
||||||
let inc_clone = Arc::clone(&state);
|
|
||||||
{
|
|
||||||
let mut lk = inc_clone.websocket_count.lock().await;
|
|
||||||
*lk += 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
let locked = state.encoder_tx.lock().await;
|
|
||||||
|
|
||||||
// Force a ws connection to mean a keyframe
|
|
||||||
let _ = locked.send(EncodeThreadInput::ForceKeyframe).await;
|
|
||||||
let _ = locked.send(EncodeThreadInput::SendFrame).await;
|
|
||||||
}
|
|
||||||
|
|
||||||
// random username
|
|
||||||
let username: Arc<String> =
|
|
||||||
Arc::new(rand::distributions::Alphanumeric.sample_string(&mut rand::thread_rng(), 16));
|
|
||||||
|
|
||||||
println!("{username} ({who}) connected.");
|
|
||||||
|
|
||||||
let send_clone = Arc::clone(&state);
|
|
||||||
let mut send_task = tokio::spawn(async move {
|
|
||||||
let mut sub = send_clone.websocket_broadcast_tx.subscribe();
|
|
||||||
|
|
||||||
while let Ok(msg) = sub.recv().await {
|
|
||||||
match msg {
|
|
||||||
WsMessage::VideoPacket { mut packet } => {
|
|
||||||
// :(. At least this copy doesn't occur on the driver threads anymore..
|
|
||||||
let data = packet.data_mut().expect("shouldn't be taken");
|
|
||||||
let msg = ws::Message::Binary(data.to_vec());
|
|
||||||
if sender.send(msg).await.is_err() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
WsMessage::Json(s) => {
|
|
||||||
let msg = ws::Message::Text(s);
|
|
||||||
if sender.send(msg).await.is_err() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
let username_clone = Arc::clone(&username);
|
|
||||||
|
|
||||||
let recv_clone = Arc::clone(&state);
|
|
||||||
|
|
||||||
let mut recv_task = tokio::spawn(async move {
|
|
||||||
while let Some(Ok(msg)) = receiver.next().await {
|
|
||||||
match msg {
|
|
||||||
Message::Text(msg) => {
|
|
||||||
// println!("{}", msg);
|
|
||||||
if let Ok(json) = serde_json::from_str::<serde_json::Value>(&msg) {
|
|
||||||
if !json["type"].is_string() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
match json["type"].as_str().unwrap() {
|
|
||||||
"chat" => {
|
|
||||||
if !json["msg"].is_string() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
let send = serde_json::json!({
|
|
||||||
"type": "chat",
|
|
||||||
"username": *username_clone,
|
|
||||||
"msg": json["msg"].as_str().unwrap()
|
|
||||||
});
|
|
||||||
|
|
||||||
recv_clone.websocket_broadcast_tx.send(WsMessage::Json(
|
|
||||||
serde_json::to_string(&send).expect("oh well"),
|
|
||||||
));
|
|
||||||
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
"key" => {
|
|
||||||
if !json["keysym"].is_number() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
if !json["pressed"].is_number() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
let keysym = json["keysym"].as_u64().unwrap() as u32;
|
|
||||||
let pressed = json["pressed"].as_u64().unwrap() == 1;
|
|
||||||
|
|
||||||
// FIXME: This would be MUCH better off being a set, so we don't
|
|
||||||
// hack-code set semantics here. Oh well.
|
|
||||||
{
|
|
||||||
let mut lock = recv_clone.inputs.lock().await;
|
|
||||||
if pressed {
|
|
||||||
if let None = lock.iter().position(|e| *e == keysym) {
|
|
||||||
lock.push(keysym);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if let Some(at) = lock.iter().position(|e| *e == keysym) {
|
|
||||||
lock.remove(at);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/*let _ = recv_clone
|
|
||||||
.engine_tx
|
|
||||||
.send(vnc_engine::VncMessageInput::KeyEvent {
|
|
||||||
keysym: keysym,
|
|
||||||
pressed: pressed,
|
|
||||||
})
|
|
||||||
.await;*/
|
|
||||||
}
|
|
||||||
|
|
||||||
"mouse" => {
|
|
||||||
if json["x"].as_u64().is_none() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
if json["y"].as_u64().is_none() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
if json["mask"].as_u64().is_none() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
//let x = json["x"].as_u64().unwrap() as u32;
|
|
||||||
//let y = json["y"].as_u64().unwrap() as u32;
|
|
||||||
//let mask = json["mask"].as_u64().unwrap() as u8;
|
|
||||||
|
|
||||||
/*let _ = recv_clone
|
|
||||||
.engine_tx
|
|
||||||
.send(vnc_engine::VncMessageInput::MouseEvent {
|
|
||||||
pt: types::Point { x: x, y: y },
|
|
||||||
buttons: mask,
|
|
||||||
})
|
|
||||||
.await;*/
|
|
||||||
}
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Message::Close(_) => break,
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
tokio::select! {
|
|
||||||
_ = (&mut send_task) => {
|
|
||||||
|
|
||||||
recv_task.abort();
|
|
||||||
},
|
|
||||||
|
|
||||||
_ = (&mut recv_task) => {
|
|
||||||
send_task.abort();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
println!("{username} ({who}) left.");
|
|
||||||
|
|
||||||
let dec_clone = Arc::clone(&state);
|
|
||||||
{
|
|
||||||
let mut lk = dec_clone.websocket_count.lock().await;
|
|
||||||
*lk -= 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
29
server/src/transport/mod.rs
Normal file
29
server/src/transport/mod.rs
Normal file
|
@ -0,0 +1,29 @@
|
||||||
|
pub mod websocket;
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub enum TransportMessage {
|
||||||
|
Text(String),
|
||||||
|
Binary(Vec<u8>)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
pub trait Transport : Send {
|
||||||
|
async fn listen<T: TransportReciever + Send + Sync + 'static>(&self, iface: Arc<T>) -> anyhow::Result<()>;
|
||||||
|
|
||||||
|
/// Broadcasts a message.
|
||||||
|
fn broadcast_message(&self, m: TransportMessage) -> anyhow::Result<()>;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
pub trait TransportReciever : Send {
|
||||||
|
async fn on_connect(&self, username: &String) -> anyhow::Result<()>;
|
||||||
|
|
||||||
|
async fn on_message(&self, username: &String, message: &String) -> anyhow::Result<()>;
|
||||||
|
|
||||||
|
async fn on_leave(&self, username: &String) -> anyhow::Result<()>;
|
||||||
|
|
||||||
|
}
|
154
server/src/transport/websocket.rs
Normal file
154
server/src/transport/websocket.rs
Normal file
|
@ -0,0 +1,154 @@
|
||||||
|
use super::{Transport, TransportMessage, TransportReciever};
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use rand::distributions::DistString;
|
||||||
|
use std::net::SocketAddr;
|
||||||
|
use tokio::sync::{
|
||||||
|
broadcast,
|
||||||
|
mpsc::{self, error::TryRecvError},
|
||||||
|
Mutex as TokioMutex,
|
||||||
|
};
|
||||||
|
|
||||||
|
use axum::{
|
||||||
|
extract::{
|
||||||
|
connect_info::ConnectInfo,
|
||||||
|
ws::{self, Message, WebSocket, WebSocketUpgrade},
|
||||||
|
State,
|
||||||
|
},
|
||||||
|
response::IntoResponse,
|
||||||
|
routing::get,
|
||||||
|
Router,
|
||||||
|
};
|
||||||
|
|
||||||
|
use futures::{sink::SinkExt, stream::StreamExt};
|
||||||
|
|
||||||
|
pub struct WebsocketTransport {
|
||||||
|
broadcast_tx: Arc<broadcast::Sender<TransportMessage>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn ws_handler<T: TransportReciever + Sync + Send + 'static>(
|
||||||
|
ws: WebSocketUpgrade,
|
||||||
|
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||||
|
|
||||||
|
State(state): State<Arc<T>>,
|
||||||
|
mut broadcast_rx: broadcast::Receiver<TransportMessage>,
|
||||||
|
) -> impl IntoResponse {
|
||||||
|
// finalize the upgrade process by returning upgrade callback.
|
||||||
|
// we can customize the callback by sending additional info such as address.
|
||||||
|
ws.on_upgrade(move |socket| handle_socket(socket, addr, state, broadcast_rx))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Actual websocket statemachine (one will be spawned per connection)
|
||||||
|
async fn handle_socket<T: TransportReciever + Sync + Send + 'static>(
|
||||||
|
socket: WebSocket,
|
||||||
|
who: SocketAddr,
|
||||||
|
state: Arc<T>,
|
||||||
|
mut broadcast_rx: broadcast::Receiver<TransportMessage>,
|
||||||
|
) {
|
||||||
|
let (mut sender, mut receiver) = socket.split();
|
||||||
|
|
||||||
|
// random username
|
||||||
|
let username: Arc<String> =
|
||||||
|
Arc::new(rand::distributions::Alphanumeric.sample_string(&mut rand::thread_rng(), 16));
|
||||||
|
|
||||||
|
let recv_clone = Arc::clone(&state);
|
||||||
|
|
||||||
|
state.on_connect(&username).await;
|
||||||
|
|
||||||
|
let mut send_task = tokio::spawn(async move {
|
||||||
|
while let Ok(msg) = broadcast_rx.recv().await {
|
||||||
|
match msg {
|
||||||
|
TransportMessage::Text(b) => {
|
||||||
|
if sender.send(Message::Text(b)).await.is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TransportMessage::Binary(b) => {
|
||||||
|
if sender.send(Message::Binary(b)).await.is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let username_clone = username.clone();
|
||||||
|
let mut recv_task = tokio::spawn(async move {
|
||||||
|
while let Some(Ok(msg)) = receiver.next().await {
|
||||||
|
match msg {
|
||||||
|
Message::Text(msg) => {
|
||||||
|
recv_clone.on_message(&username_clone, &msg).await;
|
||||||
|
}
|
||||||
|
Message::Close(_) => break,
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
tokio::select! {
|
||||||
|
_ = (&mut send_task) => {
|
||||||
|
recv_task.abort();
|
||||||
|
},
|
||||||
|
|
||||||
|
_ = (&mut recv_task) => {
|
||||||
|
send_task.abort();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
state.on_leave(&username).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WebsocketTransport {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
let (broadcast_tx, _) = broadcast::channel(32);
|
||||||
|
|
||||||
|
Self {
|
||||||
|
broadcast_tx: Arc::new(broadcast_tx),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Transport for WebsocketTransport {
|
||||||
|
async fn listen<T: super::TransportReciever + Send + Sync + 'static>(
|
||||||
|
&self,
|
||||||
|
iface: Arc<T>,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
let tx_clone = self.broadcast_tx.clone();
|
||||||
|
|
||||||
|
// Axum websocket server
|
||||||
|
let app: Router<()> =
|
||||||
|
Router::new()
|
||||||
|
.route(
|
||||||
|
"/",
|
||||||
|
get(
|
||||||
|
|ws: WebSocketUpgrade,
|
||||||
|
info: ConnectInfo<SocketAddr>,
|
||||||
|
state: State<Arc<T>>| async move {
|
||||||
|
ws_handler(ws, info, state, tx_clone.subscribe()).await
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.with_state(iface.clone());
|
||||||
|
|
||||||
|
let tcp_listener = tokio::net::TcpListener::bind("0.0.0.0:4940")
|
||||||
|
.await
|
||||||
|
.expect("failed to listen");
|
||||||
|
|
||||||
|
let axum_future = axum::serve(
|
||||||
|
tcp_listener,
|
||||||
|
app.into_make_service_with_connect_info::<SocketAddr>(),
|
||||||
|
);
|
||||||
|
|
||||||
|
axum_future.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn broadcast_message(&self, m: TransportMessage) -> anyhow::Result<()> {
|
||||||
|
let _ = self.broadcast_tx.send(m);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in a new issue