1 // #![deny(warnings)]
2 use std::collections::HashMap;
3 use std::sync::{
4     atomic::{AtomicUsize, Ordering},
5     Arc,
6 };
7 
8 use futures::{FutureExt, StreamExt};
9 use tokio::sync::{mpsc, Mutex};
10 use warp::ws::{Message, WebSocket};
11 use warp::Filter;
12 
/// Our global unique user id counter.
///
/// Starts at 1. `user_connected` takes the next value with `fetch_add` for
/// every websocket connection it handles, so each connection gets its own id.
static NEXT_USER_ID: AtomicUsize = AtomicUsize::new(1);
15 
/// Our state of currently connected users.
///
/// - Key is their id (a value handed out by `NEXT_USER_ID`)
/// - Value is the sending half of an unbounded channel carrying
///   `Result<Message, warp::Error>` items for that user
///
/// The map sits behind `Arc<Mutex<..>>` so that clones of the handle all
/// refer to the same map.
type Users = Arc<Mutex<HashMap<usize, mpsc::UnboundedSender<Result<Message, warp::Error>>>>>;
21 
22 #[tokio::main]
main()23 async fn main() {
24     pretty_env_logger::init();
25 
26     // Keep track of all connected users, key is usize, value
27     // is a websocket sender.
28     let users = Arc::new(Mutex::new(HashMap::new()));
29     // Turn our "state" into a new Filter...
30     let users = warp::any().map(move || users.clone());
31 
32     // GET /chat -> websocket upgrade
33     let chat = warp::path("chat")
34         // The `ws()` filter will prepare Websocket handshake...
35         .and(warp::ws())
36         .and(users)
37         .map(|ws: warp::ws::Ws, users| {
38             // This will call our function if the handshake succeeds.
39             ws.on_upgrade(move |socket| user_connected(socket, users))
40         });
41 
42     // GET / -> index html
43     let index = warp::path::end().map(|| warp::reply::html(INDEX_HTML));
44 
45     let routes = index.or(chat);
46 
47     warp::serve(routes).run(([127, 0, 0, 1], 3030)).await;
48 }
49 
user_connected(ws: WebSocket, users: Users)50 async fn user_connected(ws: WebSocket, users: Users) {
51     // Use a counter to assign a new unique ID for this user.
52     let my_id = NEXT_USER_ID.fetch_add(1, Ordering::Relaxed);
53 
54     eprintln!("new chat user: {}", my_id);
55 
56     // Split the socket into a sender and receive of messages.
57     let (user_ws_tx, mut user_ws_rx) = ws.split();
58 
59     // Use an unbounded channel to handle buffering and flushing of messages
60     // to the websocket...
61     let (tx, rx) = mpsc::unbounded_channel();
62     tokio::task::spawn(rx.forward(user_ws_tx).map(|result| {
63         if let Err(e) = result {
64             eprintln!("websocket send error: {}", e);
65         }
66     }));
67 
68     // Save the sender in our list of connected users.
69     users.lock().await.insert(my_id, tx);
70 
71     // Return a `Future` that is basically a state machine managing
72     // this specific user's connection.
73 
74     // Make an extra clone to give to our disconnection handler...
75     let users2 = users.clone();
76 
77     // Every time the user sends a message, broadcast it to
78     // all other users...
79     while let Some(result) = user_ws_rx.next().await {
80         let msg = match result {
81             Ok(msg) => msg,
82             Err(e) => {
83                 eprintln!("websocket error(uid={}): {}", my_id, e);
84                 break;
85             }
86         };
87         user_message(my_id, msg, &users).await;
88     }
89 
90     // user_ws_rx stream will keep processing as long as the user stays
91     // connected. Once they disconnect, then...
92     user_disconnected(my_id, &users2).await;
93 }
94 
user_message(my_id: usize, msg: Message, users: &Users)95 async fn user_message(my_id: usize, msg: Message, users: &Users) {
96     // Skip any non-Text messages...
97     let msg = if let Ok(s) = msg.to_str() {
98         s
99     } else {
100         return;
101     };
102 
103     let new_msg = format!("<User#{}>: {}", my_id, msg);
104 
105     // New message from this user, send it to everyone else (except same uid)...
106     //
107     // We use `retain` instead of a for loop so that we can reap any user that
108     // appears to have disconnected.
109     for (&uid, tx) in users.lock().await.iter_mut() {
110         if my_id != uid {
111             if let Err(_disconnected) = tx.send(Ok(Message::text(new_msg.clone()))) {
112                 // The tx is disconnected, our `user_disconnected` code
113                 // should be happening in another task, nothing more to
114                 // do here.
115             }
116         }
117     }
118 }
119 
user_disconnected(my_id: usize, users: &Users)120 async fn user_disconnected(my_id: usize, users: &Users) {
121     eprintln!("good bye user: {}", my_id);
122 
123     // Stream closed up, so remove from the user list
124     users.lock().await.remove(&my_id);
125 }
126 
/// HTML for the chat page that `main` serves at `/`.
///
/// The embedded script opens a websocket to `/chat` on the same host and
/// appends each message it receives to the `#chat` element. Clicking "Send"
/// transmits the text field's contents, clears the field, and echoes the
/// text locally prefixed with `<You>: `.
static INDEX_HTML: &str = r#"
<!DOCTYPE html>
<html>
    <head>
        <title>Warp Chat</title>
    </head>
    <body>
        <h1>warp chat</h1>
        <div id="chat">
            <p><em>Connecting...</em></p>
        </div>
        <input type="text" id="text" />
        <button type="button" id="send">Send</button>
        <script type="text/javascript">
        var uri = 'ws://' + location.host + '/chat';
        var ws = new WebSocket(uri);

        function message(data) {
            var line = document.createElement('p');
            line.innerText = data;
            chat.appendChild(line);
        }

        ws.onopen = function() {
            chat.innerHTML = "<p><em>Connected!</em></p>";
        }

        ws.onmessage = function(msg) {
            message(msg.data);
        };

        send.onclick = function() {
            var msg = text.value;
            ws.send(msg);
            text.value = '';

            message('<You>: ' + msg);
        };
        </script>
    </body>
</html>
"#;
169