1 /** @file
2 
3   Management packet marshalling.
4 
5   @section license License
6 
7   Licensed to the Apache Software Foundation (ASF) under one
8   or more contributor license agreements.  See the NOTICE file
9   distributed with this work for additional information
10   regarding copyright ownership.  The ASF licenses this file
11   to you under the Apache License, Version 2.0 (the
12   "License"); you may not use this file except in compliance
13   with the License.  You may obtain a copy of the License at
14 
15       http://www.apache.org/licenses/LICENSE-2.0
16 
17   Unless required by applicable law or agreed to in writing, software
18   distributed under the License is distributed on an "AS IS" BASIS,
19   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20   See the License for the specific language governing permissions and
21   limitations under the License.
22  */
23 
24 #include "tscore/ink_platform.h"
25 #include "tscore/ink_memory.h"
26 #include "tscore/ink_assert.h"
27 #include "MgmtMarshall.h"
28 #include "MgmtSocket.h"
29 
30 union MgmtMarshallAnyPtr {
31   MgmtMarshallInt *m_int;
32   MgmtMarshallLong *m_long;
33   MgmtMarshallString *m_string;
34   MgmtMarshallData *m_data;
35   void *m_void;
36 };
37 
// Substitute used wherever a null MgmtMarshallString must be treated as "".
// It points at a one-byte static array holding only the terminator, rather
// than at a const_cast string literal.
static char empty_string_storage[1] = {'\0'};
static char *empty                 = empty_string_storage;
39 
40 static bool
data_is_nul_terminated(const MgmtMarshallData * data)41 data_is_nul_terminated(const MgmtMarshallData *data)
42 {
43   const char *str = static_cast<const char *>(data->ptr);
44 
45   ink_assert(str);
46   if (str[data->len - 1] != '\0') {
47     return false;
48   }
49 
50   if (strlen(str) != (data->len - 1)) {
51     return false;
52   }
53 
54   return true;
55 }
56 
57 static ssize_t
socket_read_bytes(int fd,void * buf,size_t needed)58 socket_read_bytes(int fd, void *buf, size_t needed)
59 {
60   size_t nread = 0;
61 
62   // makes sure the descriptor is readable
63   if (mgmt_read_timeout(fd, MAX_TIME_WAIT, 0) <= 0) {
64     return -1;
65   }
66 
67   while (needed > nread) {
68     ssize_t ret = read(fd, buf, needed - nread);
69 
70     if (ret < 0) {
71       if (mgmt_transient_error()) {
72         continue;
73       } else {
74         return -1;
75       }
76     }
77 
78     if (ret == 0) {
79       // End of file before reading the remaining bytes.
80       errno = ECONNRESET;
81       return -1;
82     }
83 
84     buf = static_cast<uint8_t *>(buf) + ret;
85     nread += ret;
86   }
87 
88   return nread;
89 }
90 
91 static ssize_t
socket_write_bytes(int fd,const void * buf,ssize_t bytes)92 socket_write_bytes(int fd, const void *buf, ssize_t bytes)
93 {
94   ssize_t nwritten = 0;
95 
96   // makes sure the descriptor is writable
97   if (mgmt_write_timeout(fd, MAX_TIME_WAIT, 0) <= 0) {
98     return -1;
99   }
100 
101   // write until we fulfill the number
102   while (nwritten < bytes) {
103     ssize_t ret = write(fd, buf, bytes - nwritten);
104     if (ret < 0) {
105       if (mgmt_transient_error()) {
106         continue;
107       }
108       return -1;
109     }
110 
111     buf = (uint8_t *)buf + ret;
112     nwritten += ret;
113   }
114 
115   return nwritten;
116 }
117 
118 static ssize_t
socket_write_buffer(int fd,const MgmtMarshallData * data)119 socket_write_buffer(int fd, const MgmtMarshallData *data)
120 {
121   ssize_t nwrite;
122 
123   nwrite = socket_write_bytes(fd, &(data->len), 4);
124   if (nwrite != 4) {
125     goto fail;
126   }
127 
128   if (data->len) {
129     nwrite = socket_write_bytes(fd, data->ptr, data->len);
130     if (nwrite != static_cast<ssize_t>(data->len)) {
131       goto fail;
132     }
133   }
134 
135   return data->len + 4;
136 
137 fail:
138   return -1;
139 }
140 
141 static ssize_t
socket_read_buffer(int fd,MgmtMarshallData * data)142 socket_read_buffer(int fd, MgmtMarshallData *data)
143 {
144   ssize_t nread;
145 
146   ink_zero(*data);
147 
148   nread = socket_read_bytes(fd, &(data->len), 4);
149   if (nread != 4) {
150     goto fail;
151   }
152 
153   if (data->len) {
154     data->ptr = ats_malloc(data->len);
155     nread     = socket_read_bytes(fd, data->ptr, data->len);
156     if (nread != static_cast<ssize_t>(data->len)) {
157       goto fail;
158     }
159   }
160 
161   return data->len + 4;
162 
163 fail:
164   ats_free(data->ptr);
165   ink_zero(*data);
166   return -1;
167 }
168 
169 static ssize_t
buffer_read_buffer(const uint8_t * buf,size_t len,MgmtMarshallData * data)170 buffer_read_buffer(const uint8_t *buf, size_t len, MgmtMarshallData *data)
171 {
172   ink_zero(*data);
173 
174   if (len < 4) {
175     goto fail;
176   }
177 
178   memcpy(&(data->len), buf, 4);
179   buf += 4;
180   len -= 4;
181 
182   if (len < data->len) {
183     goto fail;
184   }
185 
186   if (data->len) {
187     data->ptr = ats_malloc(data->len);
188     memcpy(data->ptr, buf, data->len);
189   }
190 
191   return data->len + 4;
192 
193 fail:
194   ats_free(data->ptr);
195   ink_zero(*data);
196   return -1;
197 }
198 
199 MgmtMarshallInt
mgmt_message_length(const MgmtMarshallType * fields,unsigned count,...)200 mgmt_message_length(const MgmtMarshallType *fields, unsigned count, ...)
201 {
202   MgmtMarshallInt length;
203   va_list ap;
204 
205   va_start(ap, count);
206   length = mgmt_message_length_v(fields, count, ap);
207   va_end(ap);
208 
209   return length;
210 }
211 
212 MgmtMarshallInt
mgmt_message_length_v(const MgmtMarshallType * fields,unsigned count,va_list ap)213 mgmt_message_length_v(const MgmtMarshallType *fields, unsigned count, va_list ap)
214 {
215   MgmtMarshallAnyPtr ptr;
216   MgmtMarshallInt nbytes = 0;
217 
218   for (unsigned n = 0; n < count; ++n) {
219     switch (fields[n]) {
220     case MGMT_MARSHALL_INT:
221       ptr.m_int = va_arg(ap, MgmtMarshallInt *);
222       nbytes += 4;
223       break;
224     case MGMT_MARSHALL_LONG:
225       ptr.m_long = va_arg(ap, MgmtMarshallLong *);
226       nbytes += 8;
227       break;
228     case MGMT_MARSHALL_STRING:
229       nbytes += 4;
230       ptr.m_string = va_arg(ap, MgmtMarshallString *);
231       if (*ptr.m_string == nullptr) {
232         ptr.m_string = &empty;
233       }
234       nbytes += strlen(*ptr.m_string) + 1;
235       break;
236     case MGMT_MARSHALL_DATA:
237       nbytes += 4;
238       ptr.m_data = va_arg(ap, MgmtMarshallData *);
239       nbytes += ptr.m_data->len;
240       break;
241     default:
242       errno = EINVAL;
243       return -1;
244     }
245   }
246 
247   return nbytes;
248 }
249 
250 ssize_t
mgmt_message_write(int fd,const MgmtMarshallType * fields,unsigned count,...)251 mgmt_message_write(int fd, const MgmtMarshallType *fields, unsigned count, ...)
252 {
253   ssize_t nbytes;
254   va_list ap;
255 
256   va_start(ap, count);
257   nbytes = mgmt_message_write_v(fd, fields, count, ap);
258   va_end(ap);
259 
260   return nbytes;
261 }
262 
263 ssize_t
mgmt_message_write_v(int fd,const MgmtMarshallType * fields,unsigned count,va_list ap)264 mgmt_message_write_v(int fd, const MgmtMarshallType *fields, unsigned count, va_list ap)
265 {
266   MgmtMarshallAnyPtr ptr;
267   ssize_t nbytes = 0;
268 
269   for (unsigned n = 0; n < count; ++n) {
270     ssize_t nwritten = 0;
271 
272     switch (fields[n]) {
273     case MGMT_MARSHALL_INT:
274       ptr.m_int = va_arg(ap, MgmtMarshallInt *);
275       nwritten  = socket_write_bytes(fd, ptr.m_void, 4);
276       break;
277     case MGMT_MARSHALL_LONG:
278       ptr.m_long = va_arg(ap, MgmtMarshallLong *);
279       nwritten   = socket_write_bytes(fd, ptr.m_void, 8);
280       break;
281     case MGMT_MARSHALL_STRING: {
282       MgmtMarshallData data;
283       ptr.m_string = va_arg(ap, MgmtMarshallString *);
284       if (*ptr.m_string == nullptr) {
285         ptr.m_string = &empty;
286       }
287       data.ptr = *ptr.m_string;
288       data.len = strlen(*ptr.m_string) + 1;
289       nwritten = socket_write_buffer(fd, &data);
290       break;
291     }
292     case MGMT_MARSHALL_DATA:
293       ptr.m_data = va_arg(ap, MgmtMarshallData *);
294       nwritten   = socket_write_buffer(fd, ptr.m_data);
295       break;
296     default:
297       errno = EINVAL;
298       return -1;
299     }
300 
301     if (nwritten == -1) {
302       return -1;
303     }
304 
305     nbytes += nwritten;
306   }
307 
308   return nbytes;
309 }
310 
311 ssize_t
mgmt_message_read(int fd,const MgmtMarshallType * fields,unsigned count,...)312 mgmt_message_read(int fd, const MgmtMarshallType *fields, unsigned count, ...)
313 {
314   ssize_t nbytes;
315   va_list ap;
316 
317   va_start(ap, count);
318   nbytes = mgmt_message_read_v(fd, fields, count, ap);
319   va_end(ap);
320 
321   return nbytes;
322 }
323 
324 ssize_t
mgmt_message_read_v(int fd,const MgmtMarshallType * fields,unsigned count,va_list ap)325 mgmt_message_read_v(int fd, const MgmtMarshallType *fields, unsigned count, va_list ap)
326 {
327   MgmtMarshallAnyPtr ptr;
328   ssize_t nbytes = 0;
329 
330   for (unsigned n = 0; n < count; ++n) {
331     ssize_t nread;
332 
333     switch (fields[n]) {
334     case MGMT_MARSHALL_INT:
335       ptr.m_int = va_arg(ap, MgmtMarshallInt *);
336       nread     = socket_read_bytes(fd, ptr.m_void, 4);
337       break;
338     case MGMT_MARSHALL_LONG:
339       ptr.m_long = va_arg(ap, MgmtMarshallLong *);
340       nread      = socket_read_bytes(fd, ptr.m_void, 8);
341       break;
342     case MGMT_MARSHALL_STRING: {
343       MgmtMarshallData data;
344 
345       nread = socket_read_buffer(fd, &data);
346       if (nread == -1) {
347         break;
348       }
349 
350       ink_assert(data_is_nul_terminated(&data));
351       ptr.m_string  = va_arg(ap, MgmtMarshallString *);
352       *ptr.m_string = static_cast<char *>(data.ptr);
353       break;
354     }
355     case MGMT_MARSHALL_DATA:
356       ptr.m_data = va_arg(ap, MgmtMarshallData *);
357       nread      = socket_read_buffer(fd, ptr.m_data);
358       break;
359     default:
360       errno = EINVAL;
361       return -1;
362     }
363 
364     if (nread == -1) {
365       return -1;
366     }
367 
368     nbytes += nread;
369   }
370 
371   return nbytes;
372 }
373 
374 ssize_t
mgmt_message_marshall(void * buf,size_t remain,const MgmtMarshallType * fields,unsigned count,...)375 mgmt_message_marshall(void *buf, size_t remain, const MgmtMarshallType *fields, unsigned count, ...)
376 {
377   ssize_t nbytes = 0;
378   va_list ap;
379 
380   va_start(ap, count);
381   nbytes = mgmt_message_marshall_v(buf, remain, fields, count, ap);
382   va_end(ap);
383 
384   return nbytes;
385 }
386 
387 ssize_t
mgmt_message_marshall_v(void * buf,size_t remain,const MgmtMarshallType * fields,unsigned count,va_list ap)388 mgmt_message_marshall_v(void *buf, size_t remain, const MgmtMarshallType *fields, unsigned count, va_list ap)
389 {
390   MgmtMarshallAnyPtr ptr;
391   ssize_t nbytes = 0;
392 
393   for (unsigned n = 0; n < count; ++n) {
394     ssize_t nwritten = 0;
395 
396     switch (fields[n]) {
397     case MGMT_MARSHALL_INT:
398       if (remain < 4) {
399         goto nospace;
400       }
401       ptr.m_int = va_arg(ap, MgmtMarshallInt *);
402       memcpy(buf, ptr.m_int, 4);
403       nwritten = 4;
404       break;
405     case MGMT_MARSHALL_LONG:
406       if (remain < 8) {
407         goto nospace;
408       }
409       ptr.m_long = va_arg(ap, MgmtMarshallLong *);
410       memcpy(buf, ptr.m_long, 8);
411       nwritten = 8;
412       break;
413     case MGMT_MARSHALL_STRING: {
414       MgmtMarshallData data;
415       ptr.m_string = va_arg(ap, MgmtMarshallString *);
416       if (*ptr.m_string == nullptr) {
417         ptr.m_string = &empty;
418       }
419 
420       data.ptr = *ptr.m_string;
421       data.len = strlen(*ptr.m_string) + 1;
422 
423       if (remain < (4 + data.len)) {
424         goto nospace;
425       }
426 
427       memcpy(buf, &data.len, 4);
428       memcpy(static_cast<uint8_t *>(buf) + 4, data.ptr, data.len);
429       nwritten = 4 + data.len;
430       break;
431     }
432     case MGMT_MARSHALL_DATA:
433       ptr.m_data = va_arg(ap, MgmtMarshallData *);
434       if (remain < (4 + ptr.m_data->len)) {
435         goto nospace;
436       }
437       memcpy(buf, &(ptr.m_data->len), 4);
438       memcpy(static_cast<uint8_t *>(buf) + 4, ptr.m_data->ptr, ptr.m_data->len);
439       nwritten = 4 + ptr.m_data->len;
440       break;
441     default:
442       errno = EINVAL;
443       return -1;
444     }
445 
446     nbytes += nwritten;
447     buf = static_cast<uint8_t *>(buf) + nwritten;
448     remain -= nwritten;
449   }
450 
451   return nbytes;
452 
453 nospace:
454   errno = EMSGSIZE;
455   return -1;
456 }
457 
458 ssize_t
mgmt_message_parse(const void * buf,size_t len,const MgmtMarshallType * fields,unsigned count,...)459 mgmt_message_parse(const void *buf, size_t len, const MgmtMarshallType *fields, unsigned count, ...)
460 {
461   MgmtMarshallInt nbytes = 0;
462   va_list ap;
463 
464   va_start(ap, count);
465   nbytes = mgmt_message_parse_v(buf, len, fields, count, ap);
466   va_end(ap);
467 
468   return nbytes;
469 }
470 
471 ssize_t
mgmt_message_parse_v(const void * buf,size_t len,const MgmtMarshallType * fields,unsigned count,va_list ap)472 mgmt_message_parse_v(const void *buf, size_t len, const MgmtMarshallType *fields, unsigned count, va_list ap)
473 {
474   MgmtMarshallAnyPtr ptr;
475   ssize_t nbytes = 0;
476 
477   for (unsigned n = 0; n < count; ++n) {
478     ssize_t nread;
479 
480     switch (fields[n]) {
481     case MGMT_MARSHALL_INT:
482       if (len < 4) {
483         goto nospace;
484       }
485       ptr.m_int = va_arg(ap, MgmtMarshallInt *);
486       memcpy(ptr.m_int, buf, 4);
487       nread = 4;
488       break;
489     case MGMT_MARSHALL_LONG:
490       if (len < 8) {
491         goto nospace;
492       }
493       ptr.m_long = va_arg(ap, MgmtMarshallLong *);
494       memcpy(ptr.m_int, buf, 8);
495       nread = 8;
496       break;
497     case MGMT_MARSHALL_STRING: {
498       MgmtMarshallData data;
499       nread = buffer_read_buffer(static_cast<const uint8_t *>(buf), len, &data);
500       if (nread == -1) {
501         goto nospace;
502       }
503 
504       ink_assert(data_is_nul_terminated(&data));
505 
506       ptr.m_string  = va_arg(ap, MgmtMarshallString *);
507       *ptr.m_string = static_cast<char *>(data.ptr);
508       break;
509     }
510     case MGMT_MARSHALL_DATA:
511       ptr.m_data = va_arg(ap, MgmtMarshallData *);
512       nread      = buffer_read_buffer(static_cast<const uint8_t *>(buf), len, ptr.m_data);
513       if (nread == -1) {
514         goto nospace;
515       }
516       break;
517     default:
518       errno = EINVAL;
519       return -1;
520     }
521 
522     nbytes += nread;
523     buf = (uint8_t *)buf + nread;
524     len -= nread;
525   }
526 
527   return nbytes;
528 
529 nospace:
530   errno = EMSGSIZE;
531   return -1;
532 }
533