Skip to content

Instantly share code, notes, and snippets.

@kotobukid
Last active December 12, 2024 10:20
Show Gist options
  • Select an option

  • Save kotobukid/f43354062c80ef1a5062eea71362e61e to your computer and use it in GitHub Desktop.

Select an option

Save kotobukid/f43354062c80ef1a5062eea71362e61e to your computer and use it in GitHub Desktop.
WebSocketのブロードキャストをmpscで実現 (WebSocket broadcast implemented with mpsc channels)
# Demo crate: a WebSocket broadcast server plus a matching stdin/stdout client.
[package]
name = "ws_s"
version = "0.1.0"
edition = "2021"
# Listens on 127.0.0.1:8080 and relays every text/binary message to all connected clients.
[[bin]]
name = "server"
path = "src/server.rs"
# Connects to the URL given as the first argument; sends stdin, prints what the server sends back.
[[bin]]
name = "client"
path = "src/client.rs"
[dependencies]
anyhow = "1.0.94" # error type returned by the server's `main`
futures-util = "0.3.31" # StreamExt/SinkExt used to split and drive the WebSocket streams
log = "0.4.22" # `info!` in the server; no logger backend is listed here, so those messages are ignored unless one is added
tokio = { version = "1.42.0", features = ["full"] }
tokio-tungstenite = "0.24.0" # WebSocket handshake and framing on top of tokio (0.24 `Message` API)
futures-channel = "0.3.31" # unbounded channel from the client's stdin task to its WebSocket sink
uuid = { version = "1.11.0", features = ["v4"] } # random per-connection IDs in the server
use futures_util::{future, pin_mut, StreamExt};
use std::env;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio_tungstenite::{connect_async, tungstenite::protocol::Message};
/// Client entry point.
///
/// Connects to the WebSocket server at the URL given as the first command-line
/// argument, then pumps data in both directions until either direction finishes:
/// - stdin chunks are sent to the server as binary frames (see `read_stdin`);
/// - every frame received from the server is printed as a debug byte dump and
///   then written raw to stdout.
///
/// A WebSocket error or an unwritable stdout ends the receiving direction
/// instead of panicking.
///
/// # Panics
/// Panics if no URL argument is given or if the connection/handshake fails.
#[tokio::main]
async fn main() {
    let url = env::args()
        .nth(1)
        .expect("this program requires at least one argument");

    // Stdin is read on its own task and handed over through an unbounded channel.
    let (stdin_tx, stdin_rx) = futures_channel::mpsc::unbounded();
    tokio::spawn(read_stdin(stdin_tx));

    let (ws_stream, _) = connect_async(&url).await.expect("Failed to connect");
    println!("WebSocket handshake has been successfully completed");

    let (write, mut read) = ws_stream.split();

    // stdin -> server
    let stdin_to_ws = stdin_rx.map(Ok).forward(write);

    // server -> stdout
    let ws_to_stdout = async {
        let mut stdout = tokio::io::stdout();
        while let Some(message) = read.next().await {
            let data = match message {
                Ok(message) => message.into_data(),
                Err(err) => {
                    // Connection reset, protocol violation, ...: stop reading
                    // rather than panicking.
                    eprintln!("WebSocket error: {}", err);
                    break;
                }
            };
            println!("server says... {:?}", data);
            // tokio performs stdout writes on a separate blocking thread. Flush so
            // the bytes are out before the next message is handled.
            if stdout.write_all(&data).await.is_err() || stdout.flush().await.is_err() {
                break;
            }
        }
    };

    // Run both directions until the first one completes.
    pin_mut!(stdin_to_ws, ws_to_stdout);
    future::select(stdin_to_ws, ws_to_stdout).await;
}
/// Reads stdin in chunks of up to 1024 bytes and forwards each chunk to `tx`
/// as a binary WebSocket message.
///
/// Returns in any of these cases:
/// - on EOF or on a read error;
/// - once the receiving half of `tx` has been dropped, i.e. the WebSocket
///   writer has finished. This case previously panicked.
///
/// Note: tokio implements stdin as a blocking read on a separate thread that
/// cannot be cancelled, so runtime shutdown may wait for the pending read.
async fn read_stdin(tx: futures_channel::mpsc::UnboundedSender<Message>) {
    let mut stdin = tokio::io::stdin();
    loop {
        // A fresh buffer every iteration: its ownership moves into the Message.
        let mut buf = vec![0; 1024];
        let n = match stdin.read(&mut buf).await {
            Err(_) | Ok(0) => break,
            Ok(n) => n,
        };
        buf.truncate(n);
        if tx.unbounded_send(Message::binary(buf)).is_err() {
            // Receiver gone: nothing will ever consume further input.
            break;
        }
    }
}
use futures_util::{SinkExt, StreamExt};
use log::info;
use std::collections::HashMap;
use std::net::SocketAddrV4;
use std::sync::Arc;
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::mpsc::Sender;
use tokio::sync::Mutex;
use uuid::Uuid;
/// One registered client.
///
/// Holds its connection ID and the sending half of the channel drained by that
/// client's WebSocket writer task.
struct SocketWrapper {
    id: Uuid,
    socket: Sender<String>,
}

/// Registry of connected clients, keyed by connection ID.
///
/// The map lives behind an async mutex, which is why every method is `async`.
struct SocketManager {
    sockets: Arc<Mutex<HashMap<Uuid, SocketWrapper>>>,
}

impl SocketManager {
    /// Creates an empty registry.
    fn new() -> Self {
        Self {
            sockets: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Registers `socket` under a freshly generated random (v4) UUID and
    /// returns that ID.
    async fn add(&mut self, socket: Sender<String>) -> Uuid {
        let id = Uuid::new_v4();
        self.sockets
            .lock()
            .await
            .insert(id, SocketWrapper { id, socket });
        id
    }

    /// Unregisters the client with ID `id`, printing a line only if it was present.
    async fn remove(&mut self, id: Uuid) {
        if self.sockets.lock().await.remove(&id).is_some() {
            println!("Socket with ID {} removed", id);
        }
    }

    /// Queues a copy of `message` on every registered client's channel.
    ///
    /// The senders are cloned out of the map first, so the registry lock is not
    /// held while waiting for channel capacity; `add`/`remove`/`dump` are no
    /// longer blocked by a slow client.
    ///
    /// Sends are still sequential: a client whose queue is full delays delivery
    /// to the clients after it. A failed send (the client's receiver has been
    /// dropped, e.g. its writer task exited) is logged and skipped.
    async fn broadcast(&self, message: String) {
        // The guard is a temporary and is released at the end of this statement.
        let targets: Vec<(Uuid, Sender<String>)> = self
            .sockets
            .lock()
            .await
            .values()
            .map(|wrapper| (wrapper.id, wrapper.socket.clone()))
            .collect();

        for (id, sender) in targets {
            if let Err(err) = sender.send(message.clone()).await {
                eprintln!("Failed to send message to {}: {}", id, err);
            }
        }
    }

    /// Prints the IDs of all registered clients, one per line.
    async fn dump(&self) {
        let sockets = self.sockets.lock().await;
        println!("Current sockets:");
        for id in sockets.keys() {
            println!("\t{}", id);
        }
        println!();
    }
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let addr: SocketAddrV4 = "127.0.0.1:8080".parse()?;
let socket: std::io::Result<TcpListener> = TcpListener::bind(&addr).await;
let listener: TcpListener = socket.expect("Failed to bind socket");
let socket_manager = Arc::new(Mutex::new(SocketManager::new()));
println!("Listening on: {}", addr);
while let Ok((stream, _)) = listener.accept().await {
let socket_manager = socket_manager.clone();
tokio::spawn(async move {
accept_connection(socket_manager, stream).await;
});
}
Ok(())
}
async fn accept_connection(manager: Arc<Mutex<SocketManager>>, stream: TcpStream) {
let addr = stream
.peer_addr()
.expect("connected streams should have a peer address");
info!("Peer address: {}", addr);
let ws_stream = tokio_tungstenite::accept_async(stream)
.await
.expect("Error during the websocket handshake occurred");
info!("New WebSocket connection: {}", addr);
let (mut write, mut read) = ws_stream.split();
let (tx, mut rx) = tokio::sync::mpsc::channel(1000);
let uuid = {
let mut manager = manager.lock().await;
manager.add(tx.clone()).await
};
let manager_clone_1 = manager.clone();
let manager_clone_2 = manager.clone();
// For each incoming message, log the content to the standard output
tokio::spawn(async move {
println!("ws receive thread start.");
while let Some(Ok(msg)) = read.next().await {
if msg.is_text() || msg.is_binary() {
let message_string = msg.to_string().trim().to_string(); // 安全に加工
println!("received: {}", message_string);
// 受け取ったメッセージを全クライアントにブロードキャスト
let manager = manager_clone_1.lock().await; // ロックを取得
manager.broadcast(message_string).await;
}
}
// 削除処理
let mut manager = manager_clone_2.lock().await;
manager.remove(uuid).await; // 該当するUUIDを削除
println!("ws receive thread end.");
});
let _ = tokio::spawn(async move {
println!("echo thread start.");
while let Some(m) = rx.recv().await {
if let Err(e) = write.send(m.into()).await {
eprintln!("Error sending to WebSocket: {}", e);
break;
}
}
println!("echo thread end.")
});
manager.lock().await.dump().await;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment