1 /**
2  * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3  * SPDX-License-Identifier: Apache-2.0.
4  */
5 #include <aws/event-stream/event_stream_rpc.h>
6 
7 #include <inttypes.h>
8 
/*
 * Header names used by the event-stream RPC protocol. They are matched case-insensitively against incoming headers by
 * aws_event_stream_rpc_extract_message_metadata() in this file.
 */

/* Name of the int32 header carrying the RPC message type. */
const struct aws_byte_cursor aws_event_stream_rpc_message_type_name =
    AWS_BYTE_CUR_INIT_FROM_STRING_LITERAL(":message-type");
/* Name of the int32 header carrying the RPC message flags. */
const struct aws_byte_cursor aws_event_stream_rpc_message_flags_name =
    AWS_BYTE_CUR_INIT_FROM_STRING_LITERAL(":message-flags");
/* Name of the int32 header carrying the RPC stream id. */
const struct aws_byte_cursor aws_event_stream_rpc_stream_id_name = AWS_BYTE_CUR_INIT_FROM_STRING_LITERAL(":stream-id");
/* Name of the string header carrying the operation name.
 * NOTE(review): unlike the other names this one has no leading ':' -- presumably intentional per the protocol;
 * confirm before changing. */
const struct aws_byte_cursor aws_event_stream_rpc_operation_name = AWS_BYTE_CUR_INIT_FROM_STRING_LITERAL("operation");
15 
16 /* just a convenience function for fetching message metadata from the event stream headers on a single iteration. */
aws_event_stream_rpc_extract_message_metadata(const struct aws_array_list * message_headers,int32_t * stream_id,int32_t * message_type,int32_t * message_flags,struct aws_byte_buf * operation_name)17 int aws_event_stream_rpc_extract_message_metadata(
18     const struct aws_array_list *message_headers,
19     int32_t *stream_id,
20     int32_t *message_type,
21     int32_t *message_flags,
22     struct aws_byte_buf *operation_name) {
23     size_t length = aws_array_list_length(message_headers);
24     bool message_type_found = 0;
25     bool message_flags_found = 0;
26     bool stream_id_found = 0;
27     bool operation_name_found = 0;
28 
29     AWS_LOGF_TRACE(
30         AWS_LS_EVENT_STREAM_GENERAL, "processing message headers for rpc protocol. %zu headers to process.", length);
31 
32     for (size_t i = 0; i < length; ++i) {
33         struct aws_event_stream_header_value_pair *header = NULL;
34         aws_array_list_get_at_ptr(message_headers, (void **)&header, i);
35         struct aws_byte_buf name_buf = aws_event_stream_header_name(header);
36         AWS_LOGF_DEBUG(AWS_LS_EVENT_STREAM_GENERAL, "processing header name " PRInSTR, AWS_BYTE_BUF_PRI(name_buf));
37 
38         /* check type first since that's cheaper than a string compare */
39         if (header->header_value_type == AWS_EVENT_STREAM_HEADER_INT32) {
40 
41             struct aws_byte_buf stream_id_field = aws_byte_buf_from_array(
42                 aws_event_stream_rpc_stream_id_name.ptr, aws_event_stream_rpc_stream_id_name.len);
43             if (aws_byte_buf_eq_ignore_case(&name_buf, &stream_id_field)) {
44                 *stream_id = aws_event_stream_header_value_as_int32(header);
45                 AWS_LOGF_DEBUG(AWS_LS_EVENT_STREAM_GENERAL, "stream id header value %" PRId32, *stream_id);
46                 stream_id_found += 1;
47                 goto found;
48             }
49 
50             struct aws_byte_buf message_type_field = aws_byte_buf_from_array(
51                 aws_event_stream_rpc_message_type_name.ptr, aws_event_stream_rpc_message_type_name.len);
52             if (aws_byte_buf_eq_ignore_case(&name_buf, &message_type_field)) {
53                 *message_type = aws_event_stream_header_value_as_int32(header);
54                 AWS_LOGF_DEBUG(AWS_LS_EVENT_STREAM_GENERAL, "message type header value %" PRId32, *message_type);
55                 message_type_found += 1;
56                 goto found;
57             }
58 
59             struct aws_byte_buf message_flags_field = aws_byte_buf_from_array(
60                 aws_event_stream_rpc_message_flags_name.ptr, aws_event_stream_rpc_message_flags_name.len);
61             if (aws_byte_buf_eq_ignore_case(&name_buf, &message_flags_field)) {
62                 *message_flags = aws_event_stream_header_value_as_int32(header);
63                 AWS_LOGF_DEBUG(AWS_LS_EVENT_STREAM_GENERAL, "message flags header value %" PRId32, *message_flags);
64                 message_flags_found += 1;
65                 goto found;
66             }
67         }
68 
69         if (header->header_value_type == AWS_EVENT_STREAM_HEADER_STRING) {
70             struct aws_byte_buf operation_field = aws_byte_buf_from_array(
71                 aws_event_stream_rpc_operation_name.ptr, aws_event_stream_rpc_operation_name.len);
72 
73             if (aws_byte_buf_eq_ignore_case(&name_buf, &operation_field)) {
74                 *operation_name = aws_event_stream_header_value_as_string(header);
75                 AWS_LOGF_DEBUG(
76                     AWS_LS_EVENT_STREAM_GENERAL,
77                     "operation name header value" PRInSTR,
78                     AWS_BYTE_BUF_PRI(*operation_name));
79                 operation_name_found += 1;
80                 goto found;
81             }
82         }
83 
84         continue;
85 
86     found:
87         if (message_flags_found && message_type_found && stream_id_found && operation_name_found) {
88             return AWS_OP_SUCCESS;
89         }
90     }
91 
92     return message_flags_found && message_type_found && stream_id_found
93                ? AWS_OP_SUCCESS
94                : aws_raise_error(AWS_ERROR_EVENT_STREAM_RPC_PROTOCOL_ERROR);
95 }
96 
static const uint32_t s_bit_scrambling_magic = 0x45d9f3bU;
static const uint32_t s_bit_shift_magic = 16U;

/*
 * Hashes a stream id for use as a hash-table key.
 *
 * `to_hash` must point to a 32-bit stream id, which is read as a uint32_t. The value is put through two rounds of
 * "xor with itself shifted right by 16, then multiply by 0x45d9f3b", followed by one final xor-shift. The original
 * author describes this as a 32-bit repurposing of the splitmix64 mixing technique. The multiplier was chosen by
 * numerical analysis for good bit mixing.
 *
 * The mix is done in uint32_t, so only the low 32 bits of the returned uint64_t can be non-zero.
 */
uint64_t aws_event_stream_rpc_hash_streamid(const void *to_hash) {
    uint32_t mixed = *(const uint32_t *)to_hash;

    /* Two xor-shift / multiply rounds; unsigned arithmetic, so the multiply wraps modulo 2^32. */
    for (int pass = 0; pass < 2; ++pass) {
        mixed ^= mixed >> s_bit_shift_magic;
        mixed *= s_bit_scrambling_magic;
    }

    /* Final xor-shift folds the high half back into the low half. */
    mixed ^= mixed >> s_bit_shift_magic;

    return (uint64_t)mixed;
}
109 
aws_event_stream_rpc_streamid_eq(const void * a,const void * b)110 bool aws_event_stream_rpc_streamid_eq(const void *a, const void *b) {
111     return *(const uint32_t *)a == *(const uint32_t *)b;
112 }
113