Last active
December 12, 2024 10:20
-
-
Save kotobukid/f43354062c80ef1a5062eea71362e61e to your computer and use it in GitHub Desktop.
WebSocketのブロードキャストをmpscで実現
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| [package] | |
| name = "ws_s" | |
| version = "0.1.0" | |
| edition = "2021" | |
| [[bin]] | |
| name = "server" | |
| path = "src/server.rs" | |
| [[bin]] | |
| name = "client" | |
| path = "src/client.rs" | |
| [dependencies] | |
| anyhow = "1.0.94" | |
| futures-util = "0.3.31" | |
| log = "0.4.22" | |
| tokio = { version = "1.42.0", features = ["full"] } | |
| tokio-tungstenite = "0.24.0" | |
| futures-channel = "0.3.31" | |
| uuid = { version = "1.11.0", features = ["v4"] } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| use futures_util::{future, pin_mut, StreamExt}; | |
| use std::env; | |
| use tokio::io::{AsyncReadExt, AsyncWriteExt}; | |
| use tokio_tungstenite::{connect_async, tungstenite::protocol::Message}; | |
| #[tokio::main] | |
| async fn main() { | |
| // let url = SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 8080); | |
| let url = env::args() | |
| .nth(1) | |
| .unwrap_or_else(|| panic!("this program requires at least one argument")); | |
| let (stdin_tx, stdin_rx) = futures_channel::mpsc::unbounded(); | |
| tokio::spawn(read_stdin(stdin_tx)); | |
| let (ws_stream, _) = connect_async(&url).await.expect("Failed to connect"); | |
| println!("WebSocket handshake has been successfully completed"); | |
| let (write, read) = ws_stream.split(); | |
| let stdin_to_ws = stdin_rx.map(Ok).forward(write); | |
| let ws_to_stdout = { | |
| read.for_each(|message| async { | |
| let data = message.unwrap().into_data(); | |
| println!("server says... {:?}", data); | |
| tokio::io::stdout().write_all(&data).await.unwrap(); | |
| }) | |
| }; | |
| pin_mut!(stdin_to_ws, ws_to_stdout); | |
| future::select(stdin_to_ws, ws_to_stdout).await; | |
| } | |
| async fn read_stdin(tx: futures_channel::mpsc::UnboundedSender<Message>) { | |
| let mut stdin = tokio::io::stdin(); | |
| loop { | |
| let mut buf = vec![0; 1024]; | |
| let n = match stdin.read(&mut buf).await { | |
| Err(_) | Ok(0) => break, | |
| Ok(n) => n, | |
| }; | |
| buf.truncate(n); | |
| tx.unbounded_send(Message::binary(buf)).unwrap(); | |
| } | |
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| use futures_util::{SinkExt, StreamExt}; | |
| use log::info; | |
| use std::collections::HashMap; | |
| use std::net::SocketAddrV4; | |
| use std::sync::Arc; | |
| use tokio::net::{TcpListener, TcpStream}; | |
| use tokio::sync::mpsc::Sender; | |
| use tokio::sync::Mutex; | |
| use uuid::Uuid; | |
| struct SocketWrapper { | |
| id: Uuid, | |
| socket: Sender<String>, | |
| } | |
| struct SocketManager { | |
| sockets: Arc<Mutex<HashMap<Uuid, SocketWrapper>>>, | |
| } | |
| impl SocketManager { | |
| fn new() -> Self { | |
| Self { | |
| sockets: Arc::new(Mutex::new(HashMap::new())), | |
| } | |
| } | |
| async fn add(&mut self, socket: Sender<String>) -> Uuid { | |
| let id = Uuid::new_v4(); | |
| let socket = SocketWrapper { id, socket }; | |
| let mut sockets = self.sockets.lock().await; | |
| sockets.insert(id, socket); | |
| id | |
| } | |
| async fn remove(&mut self, id: Uuid) { | |
| let mut sockets = self.sockets.lock().await; | |
| if sockets.remove(&id).is_some() { | |
| println!("Socket with ID {} removed", id); | |
| } | |
| } | |
| async fn broadcast(&self, message: String) { | |
| let sockets = self.sockets.lock().await; | |
| for (_, socket_wrapper) in sockets.iter() { | |
| if let Err(err) = socket_wrapper.socket.send(message.clone()).await { | |
| eprintln!("Failed to send message to {}: {}", socket_wrapper.id, err); | |
| } | |
| } | |
| } | |
| async fn dump(&self) { | |
| let sockets = self.sockets.lock().await; // 非同期ロックを取得 | |
| println!("Current sockets:"); | |
| for (id, _sender) in sockets.iter() { | |
| println!("\t{}", id); | |
| } | |
| println!(); | |
| } | |
| } | |
| #[tokio::main] | |
| async fn main() -> anyhow::Result<()> { | |
| let addr: SocketAddrV4 = "127.0.0.1:8080".parse()?; | |
| let socket: std::io::Result<TcpListener> = TcpListener::bind(&addr).await; | |
| let listener: TcpListener = socket.expect("Failed to bind socket"); | |
| let socket_manager = Arc::new(Mutex::new(SocketManager::new())); | |
| println!("Listening on: {}", addr); | |
| while let Ok((stream, _)) = listener.accept().await { | |
| let socket_manager = socket_manager.clone(); | |
| tokio::spawn(async move { | |
| accept_connection(socket_manager, stream).await; | |
| }); | |
| } | |
| Ok(()) | |
| } | |
| async fn accept_connection(manager: Arc<Mutex<SocketManager>>, stream: TcpStream) { | |
| let addr = stream | |
| .peer_addr() | |
| .expect("connected streams should have a peer address"); | |
| info!("Peer address: {}", addr); | |
| let ws_stream = tokio_tungstenite::accept_async(stream) | |
| .await | |
| .expect("Error during the websocket handshake occurred"); | |
| info!("New WebSocket connection: {}", addr); | |
| let (mut write, mut read) = ws_stream.split(); | |
| let (tx, mut rx) = tokio::sync::mpsc::channel(1000); | |
| let uuid = { | |
| let mut manager = manager.lock().await; | |
| manager.add(tx.clone()).await | |
| }; | |
| let manager_clone_1 = manager.clone(); | |
| let manager_clone_2 = manager.clone(); | |
| // For each incoming message, log the content to the standard output | |
| tokio::spawn(async move { | |
| println!("ws receive thread start."); | |
| while let Some(Ok(msg)) = read.next().await { | |
| if msg.is_text() || msg.is_binary() { | |
| let message_string = msg.to_string().trim().to_string(); // 安全に加工 | |
| println!("received: {}", message_string); | |
| // 受け取ったメッセージを全クライアントにブロードキャスト | |
| let manager = manager_clone_1.lock().await; // ロックを取得 | |
| manager.broadcast(message_string).await; | |
| } | |
| } | |
| // 削除処理 | |
| let mut manager = manager_clone_2.lock().await; | |
| manager.remove(uuid).await; // 該当するUUIDを削除 | |
| println!("ws receive thread end."); | |
| }); | |
| let _ = tokio::spawn(async move { | |
| println!("echo thread start."); | |
| while let Some(m) = rx.recv().await { | |
| if let Err(e) = write.send(m.into()).await { | |
| eprintln!("Error sending to WebSocket: {}", e); | |
| break; | |
| } | |
| } | |
| println!("echo thread end.") | |
| }); | |
| manager.lock().await.dump().await; | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment