1 /*
2  * Copyright (c) 2016, 2018, Oracle and/or its affiliates. All rights reserved.
3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4  *
5  * This code is free software; you can redistribute it and/or modify it
6  * under the terms of the GNU General Public License version 2 only, as
7  * published by the Free Software Foundation.
8  *
9  * This code is distributed in the hope that it will be useful, but WITHOUT
10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12  * version 2 for more details (a copy is included in the LICENSE file that
13  * accompanied this code).
14  *
15  * You should have received a copy of the GNU General Public License version
16  * 2 along with this work; if not, write to the Free Software Foundation,
17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18  *
19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20  * or visit www.oracle.com if you need additional information or have any
21  * questions.
22  */
23 
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
25 
26 /* Copied from jdk.internal.net.http.websocket.Frame */
final class Frame {

    // The frame's opcode, as passed to the constructor
    final Opcode opcode;
    // A private copy of the payload, flipped (ready to be read from index 0)
    final ByteBuffer data;
    // The 'last' flag, as passed to the constructor
    final boolean last;

    /*
     * Creates a frame holding a private copy of the remaining bytes of 'data'.
     *
     * The bytes are copied through a slice, so the position of 'data' is left
     * unchanged. Later modifications of 'data' are not visible through this
     * frame.
     */
    public Frame(Opcode opcode, ByteBuffer data, boolean last) {
        this.opcode = opcode;
        /* copy */
        this.data = ByteBuffer.allocate(data.remaining()).put(data.slice()).flip();
        this.last = last;
    }

    // 2 bytes of fixed header + up to 8 bytes of extended payload length
    // + up to 4 bytes of masking key (see HeaderWriter.write)
    static final int MAX_HEADER_SIZE_BYTES = 2 + 8 + 4;
    // Largest payload allowed in a control frame (RFC 6455, section 5.5)
    static final int MAX_CONTROL_FRAME_PAYLOAD_SIZE = 125;

    /*
     * A frame opcode: the low 4 bits of the first header byte.
     *
     * All 16 possible 4-bit values have a constant, so ofCode never returns
     * null.
     */
    enum Opcode {

        CONTINUATION   (0x0),
        TEXT           (0x1),
        BINARY         (0x2),
        NON_CONTROL_0x3(0x3),
        NON_CONTROL_0x4(0x4),
        NON_CONTROL_0x5(0x5),
        NON_CONTROL_0x6(0x6),
        NON_CONTROL_0x7(0x7),
        CLOSE          (0x8),
        PING           (0x9),
        PONG           (0xA),
        CONTROL_0xB    (0xB),
        CONTROL_0xC    (0xC),
        CONTROL_0xD    (0xD),
        CONTROL_0xE    (0xE),
        CONTROL_0xF    (0xF);

        // Lookup table indexed by the 4-bit code
        private static final Opcode[] opcodes;

        static {
            Opcode[] values = values();
            opcodes = new Opcode[values.length];
            for (Opcode c : values) {
                opcodes[c.code] = c;
            }
        }

        private final byte code;

        Opcode(int code) {
            this.code = (byte) code;
        }

        /*
         * Returns true for opcodes 0x8..0xF, i.e. those with bit 0x8 set.
         */
        boolean isControl() {
            return (code & 0x8) != 0;
        }

        /*
         * Returns the opcode for the low 4 bits of 'code'; higher bits are
         * ignored, so a whole header byte may be passed in.
         */
        static Opcode ofCode(int code) {
            return opcodes[code & 0xF];
        }
    }

    /*
     * A utility for masking frame payload data.
     *
     * The result does not depend on the byte order of the buffers involved:
     * the i-th payload byte is always XORed with mask byte (i mod 4), where
     * mask bytes are taken from the 32-bit mask in big-endian order.
     */
    static final class Masker {

        // Exploiting ByteBuffer's ability to read/write multi-byte integers
        // (a freshly allocated ByteBuffer is big-endian)
        private final ByteBuffer acc = ByteBuffer.allocate(8);
        // The 4 mask bytes, most significant first
        private final int[] maskBytes = new int[4];
        // Index into maskBytes of the next mask byte to apply (0..3)
        private int offset;
        // The mask repeated twice, in big-endian order
        private long maskLong;

        /*
         * Reads all remaining bytes from the given input buffer, masks them
         * with the supplied mask and writes the resulting bytes to the given
         * output buffer.
         *
         * The source and the destination buffers may be the same instance.
         *
         * Throws IllegalArgumentException if 'dst' has fewer bytes remaining
         * than 'src'.
         */
        static void transferMasking(ByteBuffer src, ByteBuffer dst, int mask) {
            if (src.remaining() > dst.remaining()) {
                throw new IllegalArgumentException();
            }
            new Masker().mask(mask).transferMasking(src, dst);
        }

        /*
         * Clears this instance's state and sets the mask.
         *
         * The behaviour is as if the mask was set on a newly created instance.
         */
        Masker mask(int value) {
            acc.clear().putInt(value).putInt(value).flip();
            for (int i = 0; i < maskBytes.length; i++) {
                maskBytes[i] = acc.get(i);
            }
            offset = 0;
            maskLong = acc.getLong(0);
            return this;
        }

        /*
         * Reads as many remaining bytes as possible from the given input
         * buffer, masks them with the previously set mask and writes the
         * resulting bytes to the given output buffer.
         *
         * Both buffers' positions advance by the number of bytes transferred,
         * which is min(src.remaining(), dst.remaining()). The mask offset is
         * carried over, so consecutive calls mask a payload split into chunks.
         *
         * The source and the destination buffers may be the same instance. If
         * the mask hasn't been previously set it is assumed to be 0.
         */
        Masker transferMasking(ByteBuffer src, ByteBuffer dst) {
            begin(src, dst);
            loop(src, dst);
            end(src, dst);
            return this;
        }

        /*
         * Applies up to 3 bytes of the mask left over from the previous pass,
         * so that the galloping phase starts at mask offset 0.
         */
        private void begin(ByteBuffer src, ByteBuffer dst) {
            if (offset == 0) { // No partially applied mask from the previous invocation
                return;
            }
            int i = src.position(), j = dst.position();
            final int srcLim = src.limit(), dstLim = dst.limit();
            for (; offset < 4 && i < srcLim && j < dstLim; i++, j++, offset++)
            {
                dst.put(j, (byte) (src.get(i) ^ maskBytes[offset]));
            }
            offset &= 3; // Will become 0 if the mask has been fully applied
            src.position(i);
            dst.position(j);
        }

        /*
         * Gallops one long (mask + mask) at a time.
         *
         * Entered with offset == 0, or with one of the buffers exhausted (in
         * which case no iteration runs).
         */
        private void loop(ByteBuffer src, ByteBuffer dst) {
            int i = src.position();
            int j = dst.position();
            final int srcLongLim = src.limit() - 7, dstLongLim = dst.limit() - 7;
            // getLong/putLong use each buffer's own byte order, whereas
            // maskLong holds the mask bytes in big-endian order. Express the
            // mask in src's order, and byte-swap the result if dst's order
            // differs, so every byte meets the mask byte at its own offset.
            final boolean srcBigEndian = src.order() == ByteOrder.BIG_ENDIAN;
            final boolean dstBigEndian = dst.order() == ByteOrder.BIG_ENDIAN;
            final long m = srcBigEndian ? maskLong : Long.reverseBytes(maskLong);
            final boolean swapResult = srcBigEndian != dstBigEndian;
            for (; i < srcLongLim && j < dstLongLim; i += 8, j += 8) {
                long masked = src.getLong(i) ^ m;
                dst.putLong(j, swapResult ? Long.reverseBytes(masked) : masked);
            }
            // The loop condition guarantees i <= src.limit() and j <= dst.limit()
            src.position(i);
            dst.position(j);
        }

        /*
         * Applies up to 7 remaining from the "galloping" phase bytes of the
         * mask, leaving 'offset' where the next pass must resume.
         */
        private void end(ByteBuffer src, ByteBuffer dst) {
            assert Math.min(src.remaining(), dst.remaining()) < 8;
            final int srcLim = src.limit(), dstLim = dst.limit();
            int i = src.position(), j = dst.position();
            for (; i < srcLim && j < dstLim;
                 i++, j++, offset = (offset + 1) & 3) // offset cycles through 0..3
            {
                dst.put(j, (byte) (src.get(i) ^ maskBytes[offset]));
            }
            src.position(i);
            dst.position(j);
        }
    }

    /*
     * A builder-style writer of frame headers.
     *
     * The writer does not enforce any protocol-level rules, it simply writes a
     * header structure to the given buffer. The order of calls to intermediate
     * methods is NOT significant.
     *
     * 'firstChar' holds the first two header bytes: FIN, RSV1-3 and the opcode
     * in the high byte; the MASK bit and the 7-bit payload length in the low
     * byte.
     */
    static final class HeaderWriter {

        private char firstChar;
        private long payloadLen;
        private int maskingKey;
        private boolean mask;

        HeaderWriter fin(boolean value) {
            if (value) {
                firstChar |=  0b10000000_00000000;
            } else {
                firstChar &= ~0b10000000_00000000;
            }
            return this;
        }

        HeaderWriter rsv1(boolean value) {
            if (value) {
                firstChar |=  0b01000000_00000000;
            } else {
                firstChar &= ~0b01000000_00000000;
            }
            return this;
        }

        HeaderWriter rsv2(boolean value) {
            if (value) {
                firstChar |=  0b00100000_00000000;
            } else {
                firstChar &= ~0b00100000_00000000;
            }
            return this;
        }

        HeaderWriter rsv3(boolean value) {
            if (value) {
                firstChar |=  0b00010000_00000000;
            } else {
                firstChar &= ~0b00010000_00000000;
            }
            return this;
        }

        HeaderWriter opcode(Opcode value) {
            firstChar = (char) ((firstChar & 0xF0FF) | (value.code << 8));
            return this;
        }

        /*
         * Sets the payload length. Lengths below 126 are stored directly in
         * the 7-bit field; 126 marks a 16-bit extended length and 127 a
         * 64-bit one (written by write()).
         *
         * Throws IllegalArgumentException if 'value' is negative.
         */
        HeaderWriter payloadLen(long value) {
            if (value < 0) {
                throw new IllegalArgumentException("Negative: " + value);
            }
            payloadLen = value;
            firstChar &= 0b11111111_10000000; // Clear previous payload length leftovers
            if (payloadLen < 126) {
                firstChar |= payloadLen;
            } else if (payloadLen < 65536) {
                firstChar |= 126;
            } else {
                firstChar |= 127;
            }
            return this;
        }

        HeaderWriter mask(int value) {
            firstChar |= 0b00000000_10000000;
            maskingKey = value;
            mask = true;
            return this;
        }

        HeaderWriter noMask() {
            firstChar &= ~0b00000000_10000000;
            mask = false;
            return this;
        }

        /*
         * Writes the header to the given buffer.
         *
         * The buffer must have at least MAX_HEADER_SIZE_BYTES remaining. The
         * buffer's position is incremented by the number of bytes written.
         *
         * Multi-byte fields are always written in big-endian (network) byte
         * order, regardless of the buffer's own order; the buffer's order is
         * restored before returning.
         */
        void write(ByteBuffer buffer) {
            final ByteOrder originalOrder = buffer.order();
            buffer.order(ByteOrder.BIG_ENDIAN);
            try {
                buffer.putChar(firstChar);
                if (payloadLen >= 126) {
                    if (payloadLen < 65536) {
                        buffer.putChar((char) payloadLen);
                    } else {
                        buffer.putLong(payloadLen);
                    }
                }
                if (mask) {
                    buffer.putInt(maskingKey);
                }
            } finally {
                buffer.order(originalOrder);
            }
        }
    }

    /*
     * A consumer of frame parts.
     *
     * Frame.Reader invokes the consumer's methods in the following order:
     *
     *     fin rsv1 rsv2 rsv3 opcode mask payloadLength maskingKey? payloadData+ endFrame
     */
    interface Consumer {

        void fin(boolean value);

        void rsv1(boolean value);

        void rsv2(boolean value);

        void rsv3(boolean value);

        void opcode(Opcode value);

        void mask(boolean value);

        void payloadLen(long value);

        void maskingKey(int value);

        /*
         * Called by the Frame.Reader when a part of the (or a complete) payload
         * is ready to be consumed.
         *
         * The sum of numbers of bytes consumed in each invocation of this
         * method corresponding to the given frame WILL be equal to
         * 'payloadLen', reported to `void payloadLen(long value)` before that.
         *
         * In particular, if `payloadLen` is 0, then there WILL be a single
         * invocation to this method.
         *
         * The buffer's limit is temporarily narrowed to the bytes belonging
         * to the current frame; the consumer signals consumption by advancing
         * the position and must not move it backwards.
         *
         * No unmasking is done.
         */
        void payloadData(ByteBuffer data);

        void endFrame();
    }

    /*
     * A Reader of frames.
     *
     * No protocol-level rules are checked. The reader keeps its parsing state
     * between calls, so a frame may be fed in arbitrary chunks.
     */
    static final class Reader {

        private static final int AWAITING_FIRST_BYTE  =  1;
        private static final int AWAITING_SECOND_BYTE =  2;
        private static final int READING_16_LENGTH    =  4;
        private static final int READING_64_LENGTH    =  8;
        private static final int READING_MASK         = 16;
        private static final int READING_PAYLOAD      = 32;

        // Exploiting ByteBuffer's ability to read multi-byte integers
        // (a freshly allocated ByteBuffer is big-endian)
        private final ByteBuffer accumulator = ByteBuffer.allocate(8);
        private int state = AWAITING_FIRST_BYTE;
        private boolean mask;
        private long remainingPayloadLength;

        /*
         * Reads at most one frame from the given buffer invoking the consumer's
         * methods corresponding to the frame parts found.
         *
         * As much of the frame's payload, if any, is read. The buffer's
         * position is updated to reflect the number of bytes read; its limit
         * is left unchanged, even if the consumer throws.
         *
         * Throws IllegalArgumentException if the 64-bit payload length is
         * negative or if a payload length is not minimally encoded.
         */
        void readFrame(ByteBuffer input, Consumer consumer) {
            loop:
            while (true) {
                byte b;
                switch (state) {
                    case AWAITING_FIRST_BYTE:
                        if (!input.hasRemaining()) {
                            break loop;
                        }
                        b = input.get();
                        consumer.fin( (b & 0b10000000) != 0);
                        consumer.rsv1((b & 0b01000000) != 0);
                        consumer.rsv2((b & 0b00100000) != 0);
                        consumer.rsv3((b & 0b00010000) != 0);
                        consumer.opcode(Opcode.ofCode(b));
                        state = AWAITING_SECOND_BYTE;
                        continue loop;
                    case AWAITING_SECOND_BYTE:
                        if (!input.hasRemaining()) {
                            break loop;
                        }
                        b = input.get();
                        consumer.mask(mask = (b & 0b10000000) != 0);
                        byte p1 = (byte) (b & 0b01111111);
                        if (p1 < 126) {
                            assert p1 >= 0 : p1;
                            consumer.payloadLen(remainingPayloadLength = p1);
                            state = mask ? READING_MASK : READING_PAYLOAD;
                        } else if (p1 < 127) {
                            state = READING_16_LENGTH;
                        } else {
                            state = READING_64_LENGTH;
                        }
                        continue loop;
                    case READING_16_LENGTH:
                        if (!input.hasRemaining()) {
                            break loop;
                        }
                        b = input.get();
                        if (accumulator.put(b).position() < 2) {
                            continue loop;
                        }
                        remainingPayloadLength = accumulator.flip().getChar();
                        if (remainingPayloadLength < 126) {
                            throw notMinimalEncoding(remainingPayloadLength);
                        }
                        consumer.payloadLen(remainingPayloadLength);
                        accumulator.clear();
                        state = mask ? READING_MASK : READING_PAYLOAD;
                        continue loop;
                    case READING_64_LENGTH:
                        if (!input.hasRemaining()) {
                            break loop;
                        }
                        b = input.get();
                        if (accumulator.put(b).position() < 8) {
                            continue loop;
                        }
                        remainingPayloadLength = accumulator.flip().getLong();
                        if (remainingPayloadLength < 0) {
                            throw negativePayload(remainingPayloadLength);
                        } else if (remainingPayloadLength < 65536) {
                            throw notMinimalEncoding(remainingPayloadLength);
                        }
                        consumer.payloadLen(remainingPayloadLength);
                        accumulator.clear();
                        state = mask ? READING_MASK : READING_PAYLOAD;
                        continue loop;
                    case READING_MASK:
                        if (!input.hasRemaining()) {
                            break loop;
                        }
                        b = input.get();
                        if (accumulator.put(b).position() != 4) {
                            continue loop;
                        }
                        consumer.maskingKey(accumulator.flip().getInt());
                        accumulator.clear();
                        state = READING_PAYLOAD;
                        continue loop;
                    case READING_PAYLOAD:
                        // This state does not require any bytes to be available
                        // in the input buffer in order to proceed
                        int deliverable = (int) Math.min(remainingPayloadLength,
                                                         input.remaining());
                        int oldLimit = input.limit();
                        // Hide bytes past the end of this frame from the consumer
                        input.limit(input.position() + deliverable);
                        int consumed;
                        try {
                            if (deliverable != 0 || remainingPayloadLength == 0) {
                                consumer.payloadData(input);
                            }
                            consumed = deliverable - input.remaining();
                        } finally {
                            // Restore the caller's limit even if the consumer throws
                            input.limit(oldLimit);
                        }
                        if (consumed < 0) {
                            // The consumer moved the position backwards
                            throw new InternalError();
                        }
                        remainingPayloadLength -= consumed;
                        if (remainingPayloadLength == 0) {
                            consumer.endFrame();
                            state = AWAITING_FIRST_BYTE;
                        }
                        break loop;
                    default:
                        throw new InternalError(String.valueOf(state));
                }
            }
        }

        private static IllegalArgumentException negativePayload(long payloadLength)
        {
            return new IllegalArgumentException("Negative payload length: "
                                                        + payloadLength);
        }

        private static IllegalArgumentException notMinimalEncoding(long payloadLength)
        {
            return new IllegalArgumentException("Not minimally-encoded payload length: "
                                                      + payloadLength);
        }
    }
}
498