1 // #![deny(warnings)]
2 use std::collections::HashMap;
3 use std::sync::{
4     atomic::{AtomicUsize, Ordering},
5     Arc,
6 };
7 
8 use futures::{FutureExt, StreamExt};
9 use tokio::sync::{mpsc, RwLock};
10 use tokio_stream::wrappers::UnboundedReceiverStream;
11 use warp::ws::{Message, WebSocket};
12 use warp::Filter;
13 
/// Our global unique user id counter.
///
/// Starts at 1; each connection takes the current value via `fetch_add`, so
/// every connection gets a distinct id for the lifetime of the process.
static NEXT_USER_ID: AtomicUsize = AtomicUsize::new(1);
16 
17 /// Our state of currently connected users.
18 ///
19 /// - Key is their id
20 /// - Value is a sender of `warp::ws::Message`
21 type Users = Arc<RwLock<HashMap<usize, mpsc::UnboundedSender<Result<Message, warp::Error>>>>>;
22 
23 #[tokio::main]
main()24 async fn main() {
25     pretty_env_logger::init();
26 
27     // Keep track of all connected users, key is usize, value
28     // is a websocket sender.
29     let users = Users::default();
30     // Turn our "state" into a new Filter...
31     let users = warp::any().map(move || users.clone());
32 
33     // GET /chat -> websocket upgrade
34     let chat = warp::path("chat")
35         // The `ws()` filter will prepare Websocket handshake...
36         .and(warp::ws())
37         .and(users)
38         .map(|ws: warp::ws::Ws, users| {
39             // This will call our function if the handshake succeeds.
40             ws.on_upgrade(move |socket| user_connected(socket, users))
41         });
42 
43     // GET / -> index html
44     let index = warp::path::end().map(|| warp::reply::html(INDEX_HTML));
45 
46     let routes = index.or(chat);
47 
48     warp::serve(routes).run(([127, 0, 0, 1], 3030)).await;
49 }
50 
user_connected(ws: WebSocket, users: Users)51 async fn user_connected(ws: WebSocket, users: Users) {
52     // Use a counter to assign a new unique ID for this user.
53     let my_id = NEXT_USER_ID.fetch_add(1, Ordering::Relaxed);
54 
55     eprintln!("new chat user: {}", my_id);
56 
57     // Split the socket into a sender and receive of messages.
58     let (user_ws_tx, mut user_ws_rx) = ws.split();
59 
60     // Use an unbounded channel to handle buffering and flushing of messages
61     // to the websocket...
62     let (tx, rx) = mpsc::unbounded_channel();
63     let rx = UnboundedReceiverStream::new(rx);
64     tokio::task::spawn(rx.forward(user_ws_tx).map(|result| {
65         if let Err(e) = result {
66             eprintln!("websocket send error: {}", e);
67         }
68     }));
69 
70     // Save the sender in our list of connected users.
71     users.write().await.insert(my_id, tx);
72 
73     // Return a `Future` that is basically a state machine managing
74     // this specific user's connection.
75 
76     // Make an extra clone to give to our disconnection handler...
77     let users2 = users.clone();
78 
79     // Every time the user sends a message, broadcast it to
80     // all other users...
81     while let Some(result) = user_ws_rx.next().await {
82         let msg = match result {
83             Ok(msg) => msg,
84             Err(e) => {
85                 eprintln!("websocket error(uid={}): {}", my_id, e);
86                 break;
87             }
88         };
89         user_message(my_id, msg, &users).await;
90     }
91 
92     // user_ws_rx stream will keep processing as long as the user stays
93     // connected. Once they disconnect, then...
94     user_disconnected(my_id, &users2).await;
95 }
96 
user_message(my_id: usize, msg: Message, users: &Users)97 async fn user_message(my_id: usize, msg: Message, users: &Users) {
98     // Skip any non-Text messages...
99     let msg = if let Ok(s) = msg.to_str() {
100         s
101     } else {
102         return;
103     };
104 
105     let new_msg = format!("<User#{}>: {}", my_id, msg);
106 
107     // New message from this user, send it to everyone else (except same uid)...
108     for (&uid, tx) in users.read().await.iter() {
109         if my_id != uid {
110             if let Err(_disconnected) = tx.send(Ok(Message::text(new_msg.clone()))) {
111                 // The tx is disconnected, our `user_disconnected` code
112                 // should be happening in another task, nothing more to
113                 // do here.
114             }
115         }
116     }
117 }
118 
user_disconnected(my_id: usize, users: &Users)119 async fn user_disconnected(my_id: usize, users: &Users) {
120     eprintln!("good bye user: {}", my_id);
121 
122     // Stream closed up, so remove from the user list
123     users.write().await.remove(&my_id);
124 }
125 
/// Single-page chat client served at `GET /`.
///
/// The inline script opens a websocket to `/chat` on the same host and
/// appends each received message as a paragraph. It echoes the user's own
/// sends locally as `<You>: ...`, because the server does not relay a user's
/// messages back to that user.
static INDEX_HTML: &str = r#"<!DOCTYPE html>
<html lang="en">
    <head>
        <title>Warp Chat</title>
    </head>
    <body>
        <h1>Warp chat</h1>
        <div id="chat">
            <p><em>Connecting...</em></p>
        </div>
        <input type="text" id="text" />
        <button type="button" id="send">Send</button>
        <script type="text/javascript">
        const chat = document.getElementById('chat');
        const text = document.getElementById('text');
        const send = document.getElementById('send');
        const uri = 'ws://' + location.host + '/chat';
        const ws = new WebSocket(uri);

        function message(data) {
            const line = document.createElement('p');
            line.innerText = data;
            chat.appendChild(line);
        }

        ws.onopen = function() {
            chat.innerHTML = '<p><em>Connected!</em></p>';
        };

        ws.onmessage = function(msg) {
            message(msg.data);
        };

        ws.onclose = function() {
            chat.getElementsByTagName('em')[0].innerText = 'Disconnected!';
        };

        send.onclick = function() {
            // send() throws while CONNECTING and silently drops data once
            // CLOSED; don't echo a message that never went out.
            if (ws.readyState !== WebSocket.OPEN) {
                return;
            }
            const msg = text.value;
            ws.send(msg);
            text.value = '';

            message('<You>: ' + msg);
        };
        </script>
    </body>
</html>
"#;
173