1 /*
2  * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License").
5  * You may not use this file except in compliance with the License.
6  * A copy of the License is located at
7  *
8  *  http://aws.amazon.com/apache2.0
9  *
10  * or in the "license" file accompanying this file. This file is distributed
11  * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12  * express or implied. See the License for the specific language governing
13  * permissions and limitations under the License.
14  */
15 
16 #include <sys/param.h>
17 #include <stdint.h>
18 
19 #include "tls/extensions/s2n_client_supported_groups.h"
20 #include "tls/extensions/s2n_ec_point_format.h"
21 
22 #include "tls/s2n_tls.h"
23 #include "tls/s2n_tls_parameters.h"
24 #include "tls/s2n_security_policies.h"
25 
26 #include "utils/s2n_safety.h"
27 #include "pq-crypto/s2n_pq.h"
28 #include "tls/s2n_tls13.h"
29 
30 static int s2n_client_supported_groups_send(struct s2n_connection *conn, struct s2n_stuffer *out);
31 static int s2n_client_supported_groups_recv(struct s2n_connection *conn, struct s2n_stuffer *extension);
32 
33 const s2n_extension_type s2n_client_supported_groups_extension = {
34     .iana_value = TLS_EXTENSION_SUPPORTED_GROUPS,
35     .is_response = false,
36     .send = s2n_client_supported_groups_send,
37     .recv = s2n_client_supported_groups_recv,
38     .should_send = s2n_extension_should_send_if_ecc_enabled,
39     .if_missing = s2n_extension_noop_if_missing,
40 };
41 
s2n_extension_should_send_if_ecc_enabled(struct s2n_connection * conn)42 bool s2n_extension_should_send_if_ecc_enabled(struct s2n_connection *conn)
43 {
44     const struct s2n_security_policy *security_policy;
45     return s2n_connection_get_security_policy(conn, &security_policy) == S2N_SUCCESS
46             && s2n_ecc_is_extension_required(security_policy);
47 }
48 
s2n_client_supported_groups_send(struct s2n_connection * conn,struct s2n_stuffer * out)49 static int s2n_client_supported_groups_send(struct s2n_connection *conn, struct s2n_stuffer *out)
50 {
51     POSIX_ENSURE_REF(conn);
52 
53     const struct s2n_ecc_preferences *ecc_pref = NULL;
54     POSIX_GUARD(s2n_connection_get_ecc_preferences(conn, &ecc_pref));
55     POSIX_ENSURE_REF(ecc_pref);
56 
57     const struct s2n_kem_preferences *kem_pref = NULL;
58     POSIX_GUARD(s2n_connection_get_kem_preferences(conn, &kem_pref));
59     POSIX_ENSURE_REF(kem_pref);
60 
61     /* Group list len */
62     struct s2n_stuffer_reservation group_list_len = { 0 };
63     POSIX_GUARD(s2n_stuffer_reserve_uint16(out, &group_list_len));
64 
65     /* Send KEM groups list first */
66     if (s2n_connection_get_protocol_version(conn) >= S2N_TLS13 && s2n_pq_is_enabled()) {
67         for (size_t i = 0; i < kem_pref->tls13_kem_group_count; i++) {
68             POSIX_GUARD(s2n_stuffer_write_uint16(out, kem_pref->tls13_kem_groups[i]->iana_id));
69         }
70     }
71 
72     /* Then send curve list */
73     for (size_t i = 0; i < ecc_pref->count; i++) {
74         POSIX_GUARD(s2n_stuffer_write_uint16(out, ecc_pref->ecc_curves[i]->iana_id));
75     }
76 
77     POSIX_GUARD(s2n_stuffer_write_vector_size(&group_list_len));
78 
79     return S2N_SUCCESS;
80 }
81 
82 /* Populates the appropriate index of either the mutually_supported_curves or
83  * mutually_supported_kem_groups array based on the received IANA ID. Will
84  * ignore unrecognized IANA IDs (and return success). */
s2n_client_supported_groups_recv_iana_id(struct s2n_connection * conn,uint16_t iana_id)85 static int s2n_client_supported_groups_recv_iana_id(struct s2n_connection *conn, uint16_t iana_id) {
86     POSIX_ENSURE_REF(conn);
87 
88     const struct s2n_ecc_preferences *ecc_pref = NULL;
89     POSIX_GUARD(s2n_connection_get_ecc_preferences(conn, &ecc_pref));
90     POSIX_ENSURE_REF(ecc_pref);
91 
92     for (size_t i = 0; i < ecc_pref->count; i++) {
93         const struct s2n_ecc_named_curve *supported_curve = ecc_pref->ecc_curves[i];
94         if (iana_id == supported_curve->iana_id) {
95             conn->kex_params.mutually_supported_curves[i] = supported_curve;
96             return S2N_SUCCESS;
97         }
98     }
99 
100     /* Return early if PQ is disabled, or if TLS version is less than 1.3, so as to ignore PQ IDs */
101     if (!s2n_pq_is_enabled() || s2n_connection_get_protocol_version(conn) < S2N_TLS13) {
102         return S2N_SUCCESS;
103     }
104 
105     const struct s2n_kem_preferences *kem_pref = NULL;
106     POSIX_GUARD(s2n_connection_get_kem_preferences(conn, &kem_pref));
107     POSIX_ENSURE_REF(kem_pref);
108 
109     for (size_t i = 0; i < kem_pref->tls13_kem_group_count; i++) {
110         const struct s2n_kem_group *supported_kem_group = kem_pref->tls13_kem_groups[i];
111         if (iana_id == supported_kem_group->iana_id) {
112             conn->kex_params.mutually_supported_kem_groups[i] = supported_kem_group;
113             return S2N_SUCCESS;
114         }
115     }
116 
117     return S2N_SUCCESS;
118 }
119 
s2n_choose_supported_group(struct s2n_connection * conn)120 static int s2n_choose_supported_group(struct s2n_connection *conn) {
121     POSIX_ENSURE_REF(conn);
122 
123     const struct s2n_ecc_preferences *ecc_pref = NULL;
124     POSIX_GUARD(s2n_connection_get_ecc_preferences(conn, &ecc_pref));
125     POSIX_ENSURE_REF(ecc_pref);
126 
127     const struct s2n_kem_preferences *kem_pref = NULL;
128     POSIX_GUARD(s2n_connection_get_kem_preferences(conn, &kem_pref));
129     POSIX_ENSURE_REF(kem_pref);
130 
131     /* Ensure that only the intended group will be non-NULL (if no group is chosen, everything
132      * should be NULL). */
133     conn->kex_params.server_kem_group_params.kem_group = NULL;
134     conn->kex_params.server_kem_group_params.ecc_params.negotiated_curve = NULL;
135     conn->kex_params.server_kem_group_params.kem_params.kem = NULL;
136     conn->kex_params.server_ecc_evp_params.negotiated_curve = NULL;
137 
138     /* Prefer to negotiate hybrid PQ over ECC. If PQ is disabled, we will never choose a
139      * PQ group because the mutually_supported_kem_groups array will not have been
140      * populated with anything. */
141     for (size_t i = 0; i < kem_pref->tls13_kem_group_count; i++) {
142         const struct s2n_kem_group *candidate_kem_group = conn->kex_params.mutually_supported_kem_groups[i];
143         if (candidate_kem_group != NULL) {
144             conn->kex_params.server_kem_group_params.kem_group = candidate_kem_group;
145             conn->kex_params.server_kem_group_params.ecc_params.negotiated_curve = candidate_kem_group->curve;
146             conn->kex_params.server_kem_group_params.kem_params.kem = candidate_kem_group->kem;
147             return S2N_SUCCESS;
148         }
149     }
150 
151     for (size_t i = 0; i < ecc_pref->count; i++) {
152         const struct s2n_ecc_named_curve *candidate_curve = conn->kex_params.mutually_supported_curves[i];
153         if (candidate_curve != NULL) {
154             conn->kex_params.server_ecc_evp_params.negotiated_curve = candidate_curve;
155             return S2N_SUCCESS;
156         }
157     }
158 
159     return S2N_SUCCESS;
160 }
161 
s2n_client_supported_groups_recv(struct s2n_connection * conn,struct s2n_stuffer * extension)162 static int s2n_client_supported_groups_recv(struct s2n_connection *conn, struct s2n_stuffer *extension) {
163     POSIX_ENSURE_REF(conn);
164     POSIX_ENSURE_REF(extension);
165 
166     uint16_t size_of_all;
167     POSIX_GUARD(s2n_stuffer_read_uint16(extension, &size_of_all));
168     if (size_of_all > s2n_stuffer_data_available(extension) || (size_of_all % sizeof(uint16_t))) {
169         /* Malformed length, ignore the extension */
170         return S2N_SUCCESS;
171     }
172 
173     for (size_t i = 0; i < (size_of_all / sizeof(uint16_t)); i++) {
174         uint16_t iana_id;
175         POSIX_GUARD(s2n_stuffer_read_uint16(extension, &iana_id));
176         POSIX_GUARD(s2n_client_supported_groups_recv_iana_id(conn, iana_id));
177     }
178 
179     POSIX_GUARD(s2n_choose_supported_group(conn));
180 
181     return S2N_SUCCESS;
182 }
183 
184 /* Old-style extension functions -- remove after extensions refactor is complete */
185 
s2n_extensions_client_supported_groups_send(struct s2n_connection * conn,struct s2n_stuffer * out)186 int s2n_extensions_client_supported_groups_send(struct s2n_connection *conn, struct s2n_stuffer *out)
187 {
188     POSIX_GUARD(s2n_extension_send(&s2n_client_supported_groups_extension, conn, out));
189 
190     /* The original send method also sent ec point formats. To avoid breaking
191      * anything, I'm going to let it continue writing point formats.
192      */
193     POSIX_GUARD(s2n_extension_send(&s2n_client_ec_point_format_extension, conn, out));
194 
195     return S2N_SUCCESS;
196 }
197 
s2n_recv_client_supported_groups(struct s2n_connection * conn,struct s2n_stuffer * extension)198 int s2n_recv_client_supported_groups(struct s2n_connection *conn, struct s2n_stuffer *extension)
199 {
200     return s2n_extension_recv(&s2n_client_supported_groups_extension, conn, extension);
201 }
202