1 
2 /* interpreters module */
3 /* low-level access to interpreter primitives */
4 
5 #include "Python.h"
6 #include "frameobject.h"
7 #include "interpreteridobject.h"
8 
9 
10 static char *
_copy_raw_string(PyObject * strobj)11 _copy_raw_string(PyObject *strobj)
12 {
13     const char *str = PyUnicode_AsUTF8(strobj);
14     if (str == NULL) {
15         return NULL;
16     }
17     char *copied = PyMem_Malloc(strlen(str)+1);
18     if (copied == NULL) {
19         PyErr_NoMemory();
20         return NULL;
21     }
22     strcpy(copied, str);
23     return copied;
24 }
25 
/* Return the interpreter state reported by PyInterpreterState_Get(). */
static PyInterpreterState *
_get_current(void)
{
    // PyInterpreterState_Get() aborts if lookup fails, so don't need
    // to check the result for NULL.
    return PyInterpreterState_Get();
}
33 
34 
35 /* data-sharing-specific code ***********************************************/
36 
/* One entry of a shared namespace: a variable name paired with the
 * cross-interpreter form of its value. */
struct _sharednsitem {
    // Copy made by _copy_raw_string(); freed by _sharednsitem_clear().
    char *name;
    // Filled in by _PyObject_GetCrossInterpreterData() in
    // _sharednsitem_init(); released by _sharednsitem_clear().
    _PyCrossInterpreterData data;
};

static void _sharednsitem_clear(struct _sharednsitem *);  // forward
43 
44 static int
_sharednsitem_init(struct _sharednsitem * item,PyObject * key,PyObject * value)45 _sharednsitem_init(struct _sharednsitem *item, PyObject *key, PyObject *value)
46 {
47     item->name = _copy_raw_string(key);
48     if (item->name == NULL) {
49         return -1;
50     }
51     if (_PyObject_GetCrossInterpreterData(value, &item->data) != 0) {
52         _sharednsitem_clear(item);
53         return -1;
54     }
55     return 0;
56 }
57 
58 static void
_sharednsitem_clear(struct _sharednsitem * item)59 _sharednsitem_clear(struct _sharednsitem *item)
60 {
61     if (item->name != NULL) {
62         PyMem_Free(item->name);
63         item->name = NULL;
64     }
65     _PyCrossInterpreterData_Release(&item->data);
66 }
67 
68 static int
_sharednsitem_apply(struct _sharednsitem * item,PyObject * ns)69 _sharednsitem_apply(struct _sharednsitem *item, PyObject *ns)
70 {
71     PyObject *name = PyUnicode_FromString(item->name);
72     if (name == NULL) {
73         return -1;
74     }
75     PyObject *value = _PyCrossInterpreterData_NewObject(&item->data);
76     if (value == NULL) {
77         Py_DECREF(name);
78         return -1;
79     }
80     int res = PyDict_SetItem(ns, name, value);
81     Py_DECREF(name);
82     Py_DECREF(value);
83     return res;
84 }
85 
/* A namespace prepared for sharing: an array of `len` items.
 * Allocated by _sharedns_new() and destroyed by _sharedns_free(). */
typedef struct _sharedns {
    Py_ssize_t len;               // number of entries in `items`
    struct _sharednsitem* items;  // PyMem-allocated array of `len` entries
} _sharedns;
90 
91 static _sharedns *
_sharedns_new(Py_ssize_t len)92 _sharedns_new(Py_ssize_t len)
93 {
94     _sharedns *shared = PyMem_NEW(_sharedns, 1);
95     if (shared == NULL) {
96         PyErr_NoMemory();
97         return NULL;
98     }
99     shared->len = len;
100     shared->items = PyMem_NEW(struct _sharednsitem, len);
101     if (shared->items == NULL) {
102         PyErr_NoMemory();
103         PyMem_Free(shared);
104         return NULL;
105     }
106     return shared;
107 }
108 
109 static void
_sharedns_free(_sharedns * shared)110 _sharedns_free(_sharedns *shared)
111 {
112     for (Py_ssize_t i=0; i < shared->len; i++) {
113         _sharednsitem_clear(&shared->items[i]);
114     }
115     PyMem_Free(shared->items);
116     PyMem_Free(shared);
117 }
118 
119 static _sharedns *
_get_shared_ns(PyObject * shareable)120 _get_shared_ns(PyObject *shareable)
121 {
122     if (shareable == NULL || shareable == Py_None) {
123         return NULL;
124     }
125     Py_ssize_t len = PyDict_Size(shareable);
126     if (len == 0) {
127         return NULL;
128     }
129 
130     _sharedns *shared = _sharedns_new(len);
131     if (shared == NULL) {
132         return NULL;
133     }
134     Py_ssize_t pos = 0;
135     for (Py_ssize_t i=0; i < len; i++) {
136         PyObject *key, *value;
137         if (PyDict_Next(shareable, &pos, &key, &value) == 0) {
138             break;
139         }
140         if (_sharednsitem_init(&shared->items[i], key, value) != 0) {
141             break;
142         }
143     }
144     if (PyErr_Occurred()) {
145         _sharedns_free(shared);
146         return NULL;
147     }
148     return shared;
149 }
150 
151 static int
_sharedns_apply(_sharedns * shared,PyObject * ns)152 _sharedns_apply(_sharedns *shared, PyObject *ns)
153 {
154     for (Py_ssize_t i=0; i < shared->len; i++) {
155         if (_sharednsitem_apply(&shared->items[i], ns) != 0) {
156             return -1;
157         }
158     }
159     return 0;
160 }
161 
162 // Ultimately we'd like to preserve enough information about the
163 // exception and traceback that we could re-constitute (or at least
164 // simulate, a la traceback.TracebackException), and even chain, a copy
165 // of the exception in the calling interpreter.
166 
/* A plain-C snapshot of an exception: both fields are C strings (or NULL)
 * filled in by _sharedexception_bind() and consumed by
 * _sharedexception_apply(). */
typedef struct _sharedexception {
    char *name;  // str() of the exception type, or NULL
    char *msg;   // str() of the exception value or a failure note, or NULL
} _sharedexception;
171 
172 static _sharedexception *
_sharedexception_new(void)173 _sharedexception_new(void)
174 {
175     _sharedexception *err = PyMem_NEW(_sharedexception, 1);
176     if (err == NULL) {
177         PyErr_NoMemory();
178         return NULL;
179     }
180     err->name = NULL;
181     err->msg = NULL;
182     return err;
183 }
184 
185 static void
_sharedexception_clear(_sharedexception * exc)186 _sharedexception_clear(_sharedexception *exc)
187 {
188     if (exc->name != NULL) {
189         PyMem_Free(exc->name);
190     }
191     if (exc->msg != NULL) {
192         PyMem_Free(exc->msg);
193     }
194 }
195 
/* Release the strings held by `exc` and then the struct itself. */
static void
_sharedexception_free(_sharedexception *exc)
{
    _sharedexception_clear(exc);
    PyMem_Free(exc);
}
202 
203 static _sharedexception *
_sharedexception_bind(PyObject * exctype,PyObject * exc,PyObject * tb)204 _sharedexception_bind(PyObject *exctype, PyObject *exc, PyObject *tb)
205 {
206     assert(exctype != NULL);
207     char *failure = NULL;
208 
209     _sharedexception *err = _sharedexception_new();
210     if (err == NULL) {
211         goto finally;
212     }
213 
214     PyObject *name = PyUnicode_FromFormat("%S", exctype);
215     if (name == NULL) {
216         failure = "unable to format exception type name";
217         goto finally;
218     }
219     err->name = _copy_raw_string(name);
220     Py_DECREF(name);
221     if (err->name == NULL) {
222         if (PyErr_ExceptionMatches(PyExc_MemoryError)) {
223             failure = "out of memory copying exception type name";
224         } else {
225             failure = "unable to encode and copy exception type name";
226         }
227         goto finally;
228     }
229 
230     if (exc != NULL) {
231         PyObject *msg = PyUnicode_FromFormat("%S", exc);
232         if (msg == NULL) {
233             failure = "unable to format exception message";
234             goto finally;
235         }
236         err->msg = _copy_raw_string(msg);
237         Py_DECREF(msg);
238         if (err->msg == NULL) {
239             if (PyErr_ExceptionMatches(PyExc_MemoryError)) {
240                 failure = "out of memory copying exception message";
241             } else {
242                 failure = "unable to encode and copy exception message";
243             }
244             goto finally;
245         }
246     }
247 
248 finally:
249     if (failure != NULL) {
250         PyErr_Clear();
251         if (err->name != NULL) {
252             PyMem_Free(err->name);
253             err->name = NULL;
254         }
255         err->msg = failure;
256     }
257     return err;
258 }
259 
260 static void
_sharedexception_apply(_sharedexception * exc,PyObject * wrapperclass)261 _sharedexception_apply(_sharedexception *exc, PyObject *wrapperclass)
262 {
263     if (exc->name != NULL) {
264         if (exc->msg != NULL) {
265             PyErr_Format(wrapperclass, "%s: %s",  exc->name, exc->msg);
266         }
267         else {
268             PyErr_SetString(wrapperclass, exc->name);
269         }
270     }
271     else if (exc->msg != NULL) {
272         PyErr_SetString(wrapperclass, exc->msg);
273     }
274     else {
275         PyErr_SetNone(wrapperclass);
276     }
277 }
278 
279 
280 /* channel-specific code ****************************************************/
281 
// Selectors for which end(s) of a channel an operation applies to.
// _channelends_close_interpreter() treats values >= 0 as including the
// send end and values <= 0 as including the recv end.
#define CHANNEL_SEND 1
#define CHANNEL_BOTH 0
#define CHANNEL_RECV -1

// Channel exception classes.  They are created, and stored into the
// namespace dict it is given, by channel_exceptions_init().
static PyObject *ChannelError;
static PyObject *ChannelNotFoundError;
static PyObject *ChannelClosedError;
static PyObject *ChannelEmptyError;
static PyObject *ChannelNotEmptyError;
291 
292 static int
channel_exceptions_init(PyObject * ns)293 channel_exceptions_init(PyObject *ns)
294 {
295     // XXX Move the exceptions into per-module memory?
296 
297     // A channel-related operation failed.
298     ChannelError = PyErr_NewException("_xxsubinterpreters.ChannelError",
299                                       PyExc_RuntimeError, NULL);
300     if (ChannelError == NULL) {
301         return -1;
302     }
303     if (PyDict_SetItemString(ns, "ChannelError", ChannelError) != 0) {
304         return -1;
305     }
306 
307     // An operation tried to use a channel that doesn't exist.
308     ChannelNotFoundError = PyErr_NewException(
309             "_xxsubinterpreters.ChannelNotFoundError", ChannelError, NULL);
310     if (ChannelNotFoundError == NULL) {
311         return -1;
312     }
313     if (PyDict_SetItemString(ns, "ChannelNotFoundError", ChannelNotFoundError) != 0) {
314         return -1;
315     }
316 
317     // An operation tried to use a closed channel.
318     ChannelClosedError = PyErr_NewException(
319             "_xxsubinterpreters.ChannelClosedError", ChannelError, NULL);
320     if (ChannelClosedError == NULL) {
321         return -1;
322     }
323     if (PyDict_SetItemString(ns, "ChannelClosedError", ChannelClosedError) != 0) {
324         return -1;
325     }
326 
327     // An operation tried to pop from an empty channel.
328     ChannelEmptyError = PyErr_NewException(
329             "_xxsubinterpreters.ChannelEmptyError", ChannelError, NULL);
330     if (ChannelEmptyError == NULL) {
331         return -1;
332     }
333     if (PyDict_SetItemString(ns, "ChannelEmptyError", ChannelEmptyError) != 0) {
334         return -1;
335     }
336 
337     // An operation tried to close a non-empty channel.
338     ChannelNotEmptyError = PyErr_NewException(
339             "_xxsubinterpreters.ChannelNotEmptyError", ChannelError, NULL);
340     if (ChannelNotEmptyError == NULL) {
341         return -1;
342     }
343     if (PyDict_SetItemString(ns, "ChannelNotEmptyError", ChannelNotEmptyError) != 0) {
344         return -1;
345     }
346 
347     return 0;
348 }
349 
350 /* the channel queue */
351 
struct _channelitem;

/* A node of a channel's singly linked queue. */
typedef struct _channelitem {
    // Payload; released and freed by _channelitem_clear() unless it has
    // been detached first (see _channelitem_popped()).
    _PyCrossInterpreterData *data;
    struct _channelitem *next;  // next node in the queue, or NULL
} _channelitem;
358 
359 static _channelitem *
_channelitem_new(void)360 _channelitem_new(void)
361 {
362     _channelitem *item = PyMem_NEW(_channelitem, 1);
363     if (item == NULL) {
364         PyErr_NoMemory();
365         return NULL;
366     }
367     item->data = NULL;
368     item->next = NULL;
369     return item;
370 }
371 
372 static void
_channelitem_clear(_channelitem * item)373 _channelitem_clear(_channelitem *item)
374 {
375     if (item->data != NULL) {
376         _PyCrossInterpreterData_Release(item->data);
377         PyMem_Free(item->data);
378         item->data = NULL;
379     }
380     item->next = NULL;
381 }
382 
/* Clear the node (releasing its payload, if any) and free the node. */
static void
_channelitem_free(_channelitem *item)
{
    _channelitem_clear(item);
    PyMem_Free(item);
}
389 
390 static void
_channelitem_free_all(_channelitem * item)391 _channelitem_free_all(_channelitem *item)
392 {
393     while (item != NULL) {
394         _channelitem *last = item;
395         item = item->next;
396         _channelitem_free(last);
397     }
398 }
399 
400 static _PyCrossInterpreterData *
_channelitem_popped(_channelitem * item)401 _channelitem_popped(_channelitem *item)
402 {
403     _PyCrossInterpreterData *data = item->data;
404     item->data = NULL;
405     _channelitem_free(item);
406     return data;
407 }
408 
/* FIFO of channel items: _channelqueue_put() appends at `last` and
 * _channelqueue_get() removes from `first`. */
typedef struct _channelqueue {
    int64_t count;        // number of queued items
    _channelitem *first;  // head of the list, or NULL when empty
    _channelitem *last;   // tail of the list, or NULL when empty
} _channelqueue;
414 
415 static _channelqueue *
_channelqueue_new(void)416 _channelqueue_new(void)
417 {
418     _channelqueue *queue = PyMem_NEW(_channelqueue, 1);
419     if (queue == NULL) {
420         PyErr_NoMemory();
421         return NULL;
422     }
423     queue->count = 0;
424     queue->first = NULL;
425     queue->last = NULL;
426     return queue;
427 }
428 
429 static void
_channelqueue_clear(_channelqueue * queue)430 _channelqueue_clear(_channelqueue *queue)
431 {
432     _channelitem_free_all(queue->first);
433     queue->count = 0;
434     queue->first = NULL;
435     queue->last = NULL;
436 }
437 
/* Free every queued item and then the queue struct itself. */
static void
_channelqueue_free(_channelqueue *queue)
{
    _channelqueue_clear(queue);
    PyMem_Free(queue);
}
444 
445 static int
_channelqueue_put(_channelqueue * queue,_PyCrossInterpreterData * data)446 _channelqueue_put(_channelqueue *queue, _PyCrossInterpreterData *data)
447 {
448     _channelitem *item = _channelitem_new();
449     if (item == NULL) {
450         return -1;
451     }
452     item->data = data;
453 
454     queue->count += 1;
455     if (queue->first == NULL) {
456         queue->first = item;
457     }
458     else {
459         queue->last->next = item;
460     }
461     queue->last = item;
462     return 0;
463 }
464 
465 static _PyCrossInterpreterData *
_channelqueue_get(_channelqueue * queue)466 _channelqueue_get(_channelqueue *queue)
467 {
468     _channelitem *item = queue->first;
469     if (item == NULL) {
470         return NULL;
471     }
472     queue->first = item->next;
473     if (queue->last == item) {
474         queue->last = NULL;
475     }
476     queue->count -= 1;
477 
478     return _channelitem_popped(item);
479 }
480 
481 /* channel-interpreter associations */
482 
struct _channelend;

/* Association between one interpreter and one end (send or recv) of a
 * channel; kept in a singly linked list per end. */
typedef struct _channelend {
    struct _channelend *next;  // next association in the list, or NULL
    int64_t interp;            // ID of the associated interpreter
    int open;                  // 1 until closed by _channelends_close_end()
} _channelend;
490 
491 static _channelend *
_channelend_new(int64_t interp)492 _channelend_new(int64_t interp)
493 {
494     _channelend *end = PyMem_NEW(_channelend, 1);
495     if (end == NULL) {
496         PyErr_NoMemory();
497         return NULL;
498     }
499     end->next = NULL;
500     end->interp = interp;
501     end->open = 1;
502     return end;
503 }
504 
/* Free a single channel end.  It owns no other resources. */
static void
_channelend_free(_channelend *end)
{
    PyMem_Free(end);
}
510 
511 static void
_channelend_free_all(_channelend * end)512 _channelend_free_all(_channelend *end)
513 {
514     while (end != NULL) {
515         _channelend *last = end;
516         end = end->next;
517         _channelend_free(last);
518     }
519 }
520 
521 static _channelend *
_channelend_find(_channelend * first,int64_t interp,_channelend ** pprev)522 _channelend_find(_channelend *first, int64_t interp, _channelend **pprev)
523 {
524     _channelend *prev = NULL;
525     _channelend *end = first;
526     while (end != NULL) {
527         if (end->interp == interp) {
528             break;
529         }
530         prev = end;
531         end = end->next;
532     }
533     if (pprev != NULL) {
534         *pprev = prev;
535     }
536     return end;
537 }
538 
/* The interpreters associated with each end of a channel. */
typedef struct _channelassociations {
    // Note that list entries are never removed for interpreters for
    // which the channel is closed.  This should not be a problem in
    // practice.  Also, a channel isn't automatically closed when an
    // interpreter is destroyed.
    int64_t numsendopen;  // count of open entries in `send`
    int64_t numrecvopen;  // count of open entries in `recv`
    _channelend *send;    // list of send-end associations
    _channelend *recv;    // list of recv-end associations
} _channelends;
549 
550 static _channelends *
_channelends_new(void)551 _channelends_new(void)
552 {
553     _channelends *ends = PyMem_NEW(_channelends, 1);
554     if (ends== NULL) {
555         return NULL;
556     }
557     ends->numsendopen = 0;
558     ends->numrecvopen = 0;
559     ends->send = NULL;
560     ends->recv = NULL;
561     return ends;
562 }
563 
564 static void
_channelends_clear(_channelends * ends)565 _channelends_clear(_channelends *ends)
566 {
567     _channelend_free_all(ends->send);
568     ends->send = NULL;
569     ends->numsendopen = 0;
570 
571     _channelend_free_all(ends->recv);
572     ends->recv = NULL;
573     ends->numrecvopen = 0;
574 }
575 
/* Free both association lists and then the struct itself. */
static void
_channelends_free(_channelends *ends)
{
    _channelends_clear(ends);
    PyMem_Free(ends);
}
582 
583 static _channelend *
_channelends_add(_channelends * ends,_channelend * prev,int64_t interp,int send)584 _channelends_add(_channelends *ends, _channelend *prev, int64_t interp,
585                  int send)
586 {
587     _channelend *end = _channelend_new(interp);
588     if (end == NULL) {
589         return NULL;
590     }
591 
592     if (prev == NULL) {
593         if (send) {
594             ends->send = end;
595         }
596         else {
597             ends->recv = end;
598         }
599     }
600     else {
601         prev->next = end;
602     }
603     if (send) {
604         ends->numsendopen += 1;
605     }
606     else {
607         ends->numrecvopen += 1;
608     }
609     return end;
610 }
611 
612 static int
_channelends_associate(_channelends * ends,int64_t interp,int send)613 _channelends_associate(_channelends *ends, int64_t interp, int send)
614 {
615     _channelend *prev;
616     _channelend *end = _channelend_find(send ? ends->send : ends->recv,
617                                         interp, &prev);
618     if (end != NULL) {
619         if (!end->open) {
620             PyErr_SetString(ChannelClosedError, "channel already closed");
621             return -1;
622         }
623         // already associated
624         return 0;
625     }
626     if (_channelends_add(ends, prev, interp, send) == NULL) {
627         return -1;
628     }
629     return 0;
630 }
631 
632 static int
_channelends_is_open(_channelends * ends)633 _channelends_is_open(_channelends *ends)
634 {
635     if (ends->numsendopen != 0 || ends->numrecvopen != 0) {
636         return 1;
637     }
638     if (ends->send == NULL && ends->recv == NULL) {
639         return 1;
640     }
641     return 0;
642 }
643 
644 static void
_channelends_close_end(_channelends * ends,_channelend * end,int send)645 _channelends_close_end(_channelends *ends, _channelend *end, int send)
646 {
647     end->open = 0;
648     if (send) {
649         ends->numsendopen -= 1;
650     }
651     else {
652         ends->numrecvopen -= 1;
653     }
654 }
655 
656 static int
_channelends_close_interpreter(_channelends * ends,int64_t interp,int which)657 _channelends_close_interpreter(_channelends *ends, int64_t interp, int which)
658 {
659     _channelend *prev;
660     _channelend *end;
661     if (which >= 0) {  // send/both
662         end = _channelend_find(ends->send, interp, &prev);
663         if (end == NULL) {
664             // never associated so add it
665             end = _channelends_add(ends, prev, interp, 1);
666             if (end == NULL) {
667                 return -1;
668             }
669         }
670         _channelends_close_end(ends, end, 1);
671     }
672     if (which <= 0) {  // recv/both
673         end = _channelend_find(ends->recv, interp, &prev);
674         if (end == NULL) {
675             // never associated so add it
676             end = _channelends_add(ends, prev, interp, 0);
677             if (end == NULL) {
678                 return -1;
679             }
680         }
681         _channelends_close_end(ends, end, 0);
682     }
683     return 0;
684 }
685 
686 static void
_channelends_close_all(_channelends * ends,int which,int force)687 _channelends_close_all(_channelends *ends, int which, int force)
688 {
689     // XXX Handle the ends.
690     // XXX Handle force is True.
691 
692     // Ensure all the "send"-associated interpreters are closed.
693     _channelend *end;
694     for (end = ends->send; end != NULL; end = end->next) {
695         _channelends_close_end(ends, end, 1);
696     }
697 
698     // Ensure all the "recv"-associated interpreters are closed.
699     for (end = ends->recv; end != NULL; end = end->next) {
700         _channelends_close_end(ends, end, 0);
701     }
702 }
703 
704 /* channels */
705 
struct _channel;
struct _channel_closing;
static void _channel_clear_closing(struct _channel *);
static void _channel_finish_closing(struct _channel *);

/* State of a single channel. */
typedef struct _channel {
    PyThread_type_lock mutex;  // held by the _channel_*() operations below
    _channelqueue *queue;      // queued items
    _channelends *ends;        // per-interpreter send/recv associations
    int open;                  // 1 while the channel can be used, else 0
    // Non-NULL while a close is pending; see _channels_close() and
    // _channel_next().
    struct _channel_closing *closing;
} _PyChannelState;
718 
719 static _PyChannelState *
_channel_new(void)720 _channel_new(void)
721 {
722     _PyChannelState *chan = PyMem_NEW(_PyChannelState, 1);
723     if (chan == NULL) {
724         return NULL;
725     }
726     chan->mutex = PyThread_allocate_lock();
727     if (chan->mutex == NULL) {
728         PyMem_Free(chan);
729         PyErr_SetString(ChannelError,
730                         "can't initialize mutex for new channel");
731         return NULL;
732     }
733     chan->queue = _channelqueue_new();
734     if (chan->queue == NULL) {
735         PyMem_Free(chan);
736         return NULL;
737     }
738     chan->ends = _channelends_new();
739     if (chan->ends == NULL) {
740         _channelqueue_free(chan->queue);
741         PyMem_Free(chan);
742         return NULL;
743     }
744     chan->open = 1;
745     chan->closing = NULL;
746     return chan;
747 }
748 
749 static void
_channel_free(_PyChannelState * chan)750 _channel_free(_PyChannelState *chan)
751 {
752     _channel_clear_closing(chan);
753     PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
754     _channelqueue_free(chan->queue);
755     _channelends_free(chan->ends);
756     PyThread_release_lock(chan->mutex);
757 
758     PyThread_free_lock(chan->mutex);
759     PyMem_Free(chan);
760 }
761 
762 static int
_channel_add(_PyChannelState * chan,int64_t interp,_PyCrossInterpreterData * data)763 _channel_add(_PyChannelState *chan, int64_t interp,
764              _PyCrossInterpreterData *data)
765 {
766     int res = -1;
767     PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
768 
769     if (!chan->open) {
770         PyErr_SetString(ChannelClosedError, "channel closed");
771         goto done;
772     }
773     if (_channelends_associate(chan->ends, interp, 1) != 0) {
774         goto done;
775     }
776 
777     if (_channelqueue_put(chan->queue, data) != 0) {
778         goto done;
779     }
780 
781     res = 0;
782 done:
783     PyThread_release_lock(chan->mutex);
784     return res;
785 }
786 
/* Pop the oldest queued item from the channel on behalf of `interp`.
 *
 * The interpreter is associated with the recv end first.  Returns the
 * item's data, or NULL when the channel is closed (ChannelClosedError),
 * the association fails (error set by _channelends_associate()), or the
 * queue is empty (this function sets no exception in that case). */
static _PyCrossInterpreterData *
_channel_next(_PyChannelState *chan, int64_t interp)
{
    _PyCrossInterpreterData *data = NULL;
    PyThread_acquire_lock(chan->mutex, WAIT_LOCK);

    if (!chan->open) {
        PyErr_SetString(ChannelClosedError, "channel closed");
        goto done;
    }
    if (_channelends_associate(chan->ends, interp, 0) != 0) {
        goto done;
    }

    data = _channelqueue_get(chan->queue);
    // The queue is drained and a close is pending: mark the channel closed.
    if (data == NULL && !PyErr_Occurred() && chan->closing != NULL) {
        chan->open = 0;
    }

done:
    PyThread_release_lock(chan->mutex);
    // NOTE(review): the count is read after the mutex has been released,
    // and this runs on the error paths as well -- confirm it is intended.
    if (chan->queue->count == 0) {
        _channel_finish_closing(chan);
    }
    return data;
}
813 
814 static int
_channel_close_interpreter(_PyChannelState * chan,int64_t interp,int end)815 _channel_close_interpreter(_PyChannelState *chan, int64_t interp, int end)
816 {
817     PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
818 
819     int res = -1;
820     if (!chan->open) {
821         PyErr_SetString(ChannelClosedError, "channel already closed");
822         goto done;
823     }
824 
825     if (_channelends_close_interpreter(chan->ends, interp, end) != 0) {
826         goto done;
827     }
828     chan->open = _channelends_is_open(chan->ends);
829 
830     res = 0;
831 done:
832     PyThread_release_lock(chan->mutex);
833     return res;
834 }
835 
836 static int
_channel_close_all(_PyChannelState * chan,int end,int force)837 _channel_close_all(_PyChannelState *chan, int end, int force)
838 {
839     int res = -1;
840     PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
841 
842     if (!chan->open) {
843         PyErr_SetString(ChannelClosedError, "channel already closed");
844         goto done;
845     }
846 
847     if (!force && chan->queue->count > 0) {
848         PyErr_SetString(ChannelNotEmptyError,
849                         "may not be closed if not empty (try force=True)");
850         goto done;
851     }
852 
853     chan->open = 0;
854 
855     // We *could* also just leave these in place, since we've marked
856     // the channel as closed already.
857     _channelends_close_all(chan->ends, end, force);
858 
859     res = 0;
860 done:
861     PyThread_release_lock(chan->mutex);
862     return res;
863 }
864 
865 /* the set of channels */
866 
struct _channelref;

/* Entry in the global list of channels, mapping a channel ID to its
 * state. */
typedef struct _channelref {
    int64_t id;                // channel ID from _channels_next_id()
    _PyChannelState *chan;     // set to NULL by _channels_close()
    struct _channelref *next;  // next entry in the list, or NULL
    // Starts at 0.  Presumably counts the live channel ID objects that
    // refer to this channel -- TODO confirm against
    // _channels_add_id_object().
    Py_ssize_t objcount;
} _channelref;
875 
876 static _channelref *
_channelref_new(int64_t id,_PyChannelState * chan)877 _channelref_new(int64_t id, _PyChannelState *chan)
878 {
879     _channelref *ref = PyMem_NEW(_channelref, 1);
880     if (ref == NULL) {
881         return NULL;
882     }
883     ref->id = id;
884     ref->chan = chan;
885     ref->next = NULL;
886     ref->objcount = 0;
887     return ref;
888 }
889 
890 //static void
891 //_channelref_clear(_channelref *ref)
892 //{
893 //    ref->id = -1;
894 //    ref->chan = NULL;
895 //    ref->next = NULL;
896 //    ref->objcount = 0;
897 //}
898 
899 static void
_channelref_free(_channelref * ref)900 _channelref_free(_channelref *ref)
901 {
902     if (ref->chan != NULL) {
903         _channel_clear_closing(ref->chan);
904     }
905     //_channelref_clear(ref);
906     PyMem_Free(ref);
907 }
908 
909 static _channelref *
_channelref_find(_channelref * first,int64_t id,_channelref ** pprev)910 _channelref_find(_channelref *first, int64_t id, _channelref **pprev)
911 {
912     _channelref *prev = NULL;
913     _channelref *ref = first;
914     while (ref != NULL) {
915         if (ref->id == id) {
916             break;
917         }
918         prev = ref;
919         ref = ref->next;
920     }
921     if (pprev != NULL) {
922         *pprev = prev;
923     }
924     return ref;
925 }
926 
/* The set of all channels, kept as a linked list of refs. */
typedef struct _channels {
    PyThread_type_lock mutex;  // held while the list is read or modified
    _channelref *head;         // first entry of the list, or NULL
    // Incremented by _channels_add() and decremented by
    // _channels_remove_ref().
    int64_t numopen;
    int64_t next_id;           // next ID handed out by _channels_next_id()
} _channels;
933 
934 static int
_channels_init(_channels * channels)935 _channels_init(_channels *channels)
936 {
937     if (channels->mutex == NULL) {
938         channels->mutex = PyThread_allocate_lock();
939         if (channels->mutex == NULL) {
940             PyErr_SetString(ChannelError,
941                             "can't initialize mutex for channel management");
942             return -1;
943         }
944     }
945     channels->head = NULL;
946     channels->numopen = 0;
947     channels->next_id = 0;
948     return 0;
949 }
950 
951 static int64_t
_channels_next_id(_channels * channels)952 _channels_next_id(_channels *channels)  // needs lock
953 {
954     int64_t id = channels->next_id;
955     if (id < 0) {
956         /* overflow */
957         PyErr_SetString(ChannelError,
958                         "failed to get a channel ID");
959         return -1;
960     }
961     channels->next_id += 1;
962     return id;
963 }
964 
/* Look up the open channel with ID `id`.
 *
 * Returns the channel state, or NULL with ChannelNotFoundError (unknown
 * ID) or ChannelClosedError (channel gone or not open) set.
 *
 * When `pmutex` is not NULL and the lookup succeeds, channels->mutex is
 * left locked and stored in *pmutex; the caller must release it.  In
 * every other case the mutex is released before returning and *pmutex
 * (if given) is NULL. */
static _PyChannelState *
_channels_lookup(_channels *channels, int64_t id, PyThread_type_lock *pmutex)
{
    _PyChannelState *chan = NULL;
    PyThread_acquire_lock(channels->mutex, WAIT_LOCK);
    if (pmutex != NULL) {
        *pmutex = NULL;
    }

    _channelref *ref = _channelref_find(channels->head, id, NULL);
    if (ref == NULL) {
        PyErr_Format(ChannelNotFoundError, "channel %" PRId64 " not found", id);
        goto done;
    }
    if (ref->chan == NULL || !ref->chan->open) {
        PyErr_Format(ChannelClosedError, "channel %" PRId64 " closed", id);
        goto done;
    }

    if (pmutex != NULL) {
        // The mutex will be released by the caller.
        *pmutex = channels->mutex;
    }

    chan = ref->chan;
done:
    // Keep the lock only when it was handed to the caller above.
    if (pmutex == NULL || *pmutex == NULL) {
        PyThread_release_lock(channels->mutex);
    }
    return chan;
}
996 
997 static int64_t
_channels_add(_channels * channels,_PyChannelState * chan)998 _channels_add(_channels *channels, _PyChannelState *chan)
999 {
1000     int64_t cid = -1;
1001     PyThread_acquire_lock(channels->mutex, WAIT_LOCK);
1002 
1003     // Create a new ref.
1004     int64_t id = _channels_next_id(channels);
1005     if (id < 0) {
1006         goto done;
1007     }
1008     _channelref *ref = _channelref_new(id, chan);
1009     if (ref == NULL) {
1010         goto done;
1011     }
1012 
1013     // Add it to the list.
1014     // We assume that the channel is a new one (not already in the list).
1015     ref->next = channels->head;
1016     channels->head = ref;
1017     channels->numopen += 1;
1018 
1019     cid = id;
1020 done:
1021     PyThread_release_lock(channels->mutex);
1022     return cid;
1023 }
1024 
1025 /* forward */
1026 static int _channel_set_closing(struct _channelref *, PyThread_type_lock);
1027 
/* Close the channel with ID `cid` for all interpreters.
 *
 * Runs under channels->mutex.  `end` and `force` are passed on to
 * _channel_close_all().  If `pchan` is not NULL it receives the channel
 * state on success (and NULL on failure); otherwise a channel that closed
 * immediately is freed here.
 *
 * When closing the send end fails only because the queue is not empty,
 * the channel is marked as "closing" instead and 0 is returned.
 *
 * Returns 0 on success, or -1 with an exception set (ChannelNotFoundError,
 * ChannelClosedError, or whatever _channel_close_all() /
 * _channel_set_closing() raised). */
static int
_channels_close(_channels *channels, int64_t cid, _PyChannelState **pchan,
                int end, int force)
{
    int res = -1;
    PyThread_acquire_lock(channels->mutex, WAIT_LOCK);
    if (pchan != NULL) {
        *pchan = NULL;
    }

    _channelref *ref = _channelref_find(channels->head, cid, NULL);
    if (ref == NULL) {
        PyErr_Format(ChannelNotFoundError, "channel %" PRId64 " not found", cid);
        goto done;
    }

    if (ref->chan == NULL) {
        // The ref exists but its channel has already been closed.
        PyErr_Format(ChannelClosedError, "channel %" PRId64 " closed", cid);
        goto done;
    }
    else if (!force && end == CHANNEL_SEND && ref->chan->closing != NULL) {
        // A deferred close of the send end is already pending.
        PyErr_Format(ChannelClosedError, "channel %" PRId64 " closed", cid);
        goto done;
    }
    else {
        if (_channel_close_all(ref->chan, end, force) != 0) {
            if (end == CHANNEL_SEND &&
                    PyErr_ExceptionMatches(ChannelNotEmptyError)) {
                if (ref->chan->closing != NULL) {
                    PyErr_Format(ChannelClosedError,
                                 "channel %" PRId64 " closed", cid);
                    goto done;
                }
                // Mark the channel as closing and return.  The channel
                // will be cleaned up in _channel_next().
                PyErr_Clear();
                if (_channel_set_closing(ref, channels->mutex) != 0) {
                    goto done;
                }
                if (pchan != NULL) {
                    *pchan = ref->chan;
                }
                res = 0;
            }
            // Every other failure is propagated with res still -1.
            goto done;
        }
        // Closed immediately: hand the channel to the caller, or free it.
        if (pchan != NULL) {
            *pchan = ref->chan;
        }
        else  {
            _channel_free(ref->chan);
        }
        ref->chan = NULL;
    }

    res = 0;
done:
    PyThread_release_lock(channels->mutex);
    return res;
}
1088 
1089 static void
_channels_remove_ref(_channels * channels,_channelref * ref,_channelref * prev,_PyChannelState ** pchan)1090 _channels_remove_ref(_channels *channels, _channelref *ref, _channelref *prev,
1091                      _PyChannelState **pchan)
1092 {
1093     if (ref == channels->head) {
1094         channels->head = ref->next;
1095     }
1096     else {
1097         prev->next = ref->next;
1098     }
1099     channels->numopen -= 1;
1100 
1101     if (pchan != NULL) {
1102         *pchan = ref->chan;
1103     }
1104     _channelref_free(ref);
1105 }
1106 
1107 static int
_channels_remove(_channels * channels,int64_t id,_PyChannelState ** pchan)1108 _channels_remove(_channels *channels, int64_t id, _PyChannelState **pchan)
1109 {
1110     int res = -1;
1111     PyThread_acquire_lock(channels->mutex, WAIT_LOCK);
1112 
1113     if (pchan != NULL) {
1114         *pchan = NULL;
1115     }
1116 
1117     _channelref *prev = NULL;
1118     _channelref *ref = _channelref_find(channels->head, id, &prev);
1119     if (ref == NULL) {
1120         PyErr_Format(ChannelNotFoundError, "channel %" PRId64 " not found", id);
1121         goto done;
1122     }
1123 
1124     _channels_remove_ref(channels, ref, prev, pchan);
1125 
1126     res = 0;
1127 done:
1128     PyThread_release_lock(channels->mutex);
1129     return res;
1130 }
1131 
1132 static int
_channels_add_id_object(_channels * channels,int64_t id)1133 _channels_add_id_object(_channels *channels, int64_t id)
1134 {
1135     int res = -1;
1136     PyThread_acquire_lock(channels->mutex, WAIT_LOCK);
1137 
1138     _channelref *ref = _channelref_find(channels->head, id, NULL);
1139     if (ref == NULL) {
1140         PyErr_Format(ChannelNotFoundError, "channel %" PRId64 " not found", id);
1141         goto done;
1142     }
1143     ref->objcount += 1;
1144 
1145     res = 0;
1146 done:
1147     PyThread_release_lock(channels->mutex);
1148     return res;
1149 }
1150 
1151 static void
_channels_drop_id_object(_channels * channels,int64_t id)1152 _channels_drop_id_object(_channels *channels, int64_t id)
1153 {
1154     PyThread_acquire_lock(channels->mutex, WAIT_LOCK);
1155 
1156     _channelref *prev = NULL;
1157     _channelref *ref = _channelref_find(channels->head, id, &prev);
1158     if (ref == NULL) {
1159         // Already destroyed.
1160         goto done;
1161     }
1162     ref->objcount -= 1;
1163 
1164     // Destroy if no longer used.
1165     if (ref->objcount == 0) {
1166         _PyChannelState *chan = NULL;
1167         _channels_remove_ref(channels, ref, prev, &chan);
1168         if (chan != NULL) {
1169             _channel_free(chan);
1170         }
1171     }
1172 
1173 done:
1174     PyThread_release_lock(channels->mutex);
1175 }
1176 
1177 static int64_t *
_channels_list_all(_channels * channels,int64_t * count)1178 _channels_list_all(_channels *channels, int64_t *count)
1179 {
1180     int64_t *cids = NULL;
1181     PyThread_acquire_lock(channels->mutex, WAIT_LOCK);
1182     int64_t *ids = PyMem_NEW(int64_t, (Py_ssize_t)(channels->numopen));
1183     if (ids == NULL) {
1184         goto done;
1185     }
1186     _channelref *ref = channels->head;
1187     for (int64_t i=0; ref != NULL; ref = ref->next, i++) {
1188         ids[i] = ref->id;
1189     }
1190     *count = channels->numopen;
1191 
1192     cids = ids;
1193 done:
1194     PyThread_release_lock(channels->mutex);
1195     return cids;
1196 }
1197 
1198 /* support for closing non-empty channels */
1199 
// Marker attached to a channel (chan->closing) while a deferred close is
// pending.  A non-NULL chan->closing is what the rest of this file tests
// to treat the channel as "closing".
struct _channel_closing {
    // The registry entry whose channel is being closed; its "chan" field
    // is reset when the close is finished.
    struct _channelref *ref;
};
1203 
1204 static int
_channel_set_closing(struct _channelref * ref,PyThread_type_lock mutex)1205 _channel_set_closing(struct _channelref *ref, PyThread_type_lock mutex) {
1206     struct _channel *chan = ref->chan;
1207     if (chan == NULL) {
1208         // already closed
1209         return 0;
1210     }
1211     int res = -1;
1212     PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
1213     if (chan->closing != NULL) {
1214         PyErr_SetString(ChannelClosedError, "channel closed");
1215         goto done;
1216     }
1217     chan->closing = PyMem_NEW(struct _channel_closing, 1);
1218     if (chan->closing == NULL) {
1219         goto done;
1220     }
1221     chan->closing->ref = ref;
1222 
1223     res = 0;
1224 done:
1225     PyThread_release_lock(chan->mutex);
1226     return res;
1227 }
1228 
1229 static void
_channel_clear_closing(struct _channel * chan)1230 _channel_clear_closing(struct _channel *chan) {
1231     PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
1232     if (chan->closing != NULL) {
1233         PyMem_Free(chan->closing);
1234         chan->closing = NULL;
1235     }
1236     PyThread_release_lock(chan->mutex);
1237 }
1238 
1239 static void
_channel_finish_closing(struct _channel * chan)1240 _channel_finish_closing(struct _channel *chan) {
1241     struct _channel_closing *closing = chan->closing;
1242     if (closing == NULL) {
1243         return;
1244     }
1245     _channelref *ref = closing->ref;
1246     _channel_clear_closing(chan);
1247     // Do the things that would have been done in _channels_close().
1248     ref->chan = NULL;
1249     _channel_free(chan);
1250 }
1251 
1252 /* "high"-level channel-related functions */
1253 
1254 static int64_t
_channel_create(_channels * channels)1255 _channel_create(_channels *channels)
1256 {
1257     _PyChannelState *chan = _channel_new();
1258     if (chan == NULL) {
1259         return -1;
1260     }
1261     int64_t id = _channels_add(channels, chan);
1262     if (id < 0) {
1263         _channel_free(chan);
1264         return -1;
1265     }
1266     return id;
1267 }
1268 
1269 static int
_channel_destroy(_channels * channels,int64_t id)1270 _channel_destroy(_channels *channels, int64_t id)
1271 {
1272     _PyChannelState *chan = NULL;
1273     if (_channels_remove(channels, id, &chan) != 0) {
1274         return -1;
1275     }
1276     if (chan != NULL) {
1277         _channel_free(chan);
1278     }
1279     return 0;
1280 }
1281 
1282 static int
_channel_send(_channels * channels,int64_t id,PyObject * obj)1283 _channel_send(_channels *channels, int64_t id, PyObject *obj)
1284 {
1285     PyInterpreterState *interp = _get_current();
1286     if (interp == NULL) {
1287         return -1;
1288     }
1289 
1290     // Look up the channel.
1291     PyThread_type_lock mutex = NULL;
1292     _PyChannelState *chan = _channels_lookup(channels, id, &mutex);
1293     if (chan == NULL) {
1294         return -1;
1295     }
1296     // Past this point we are responsible for releasing the mutex.
1297 
1298     if (chan->closing != NULL) {
1299         PyErr_Format(ChannelClosedError, "channel %" PRId64 " closed", id);
1300         PyThread_release_lock(mutex);
1301         return -1;
1302     }
1303 
1304     // Convert the object to cross-interpreter data.
1305     _PyCrossInterpreterData *data = PyMem_NEW(_PyCrossInterpreterData, 1);
1306     if (data == NULL) {
1307         PyThread_release_lock(mutex);
1308         return -1;
1309     }
1310     if (_PyObject_GetCrossInterpreterData(obj, data) != 0) {
1311         PyThread_release_lock(mutex);
1312         PyMem_Free(data);
1313         return -1;
1314     }
1315 
1316     // Add the data to the channel.
1317     int res = _channel_add(chan, PyInterpreterState_GetID(interp), data);
1318     PyThread_release_lock(mutex);
1319     if (res != 0) {
1320         _PyCrossInterpreterData_Release(data);
1321         PyMem_Free(data);
1322         return -1;
1323     }
1324 
1325     return 0;
1326 }
1327 
1328 static PyObject *
_channel_recv(_channels * channels,int64_t id)1329 _channel_recv(_channels *channels, int64_t id)
1330 {
1331     PyInterpreterState *interp = _get_current();
1332     if (interp == NULL) {
1333         return NULL;
1334     }
1335 
1336     // Look up the channel.
1337     PyThread_type_lock mutex = NULL;
1338     _PyChannelState *chan = _channels_lookup(channels, id, &mutex);
1339     if (chan == NULL) {
1340         return NULL;
1341     }
1342     // Past this point we are responsible for releasing the mutex.
1343 
1344     // Pop off the next item from the channel.
1345     _PyCrossInterpreterData *data = _channel_next(chan, PyInterpreterState_GetID(interp));
1346     PyThread_release_lock(mutex);
1347     if (data == NULL) {
1348         return NULL;
1349     }
1350 
1351     // Convert the data back to an object.
1352     PyObject *obj = _PyCrossInterpreterData_NewObject(data);
1353     _PyCrossInterpreterData_Release(data);
1354     PyMem_Free(data);
1355     if (obj == NULL) {
1356         return NULL;
1357     }
1358 
1359     return obj;
1360 }
1361 
1362 static int
_channel_drop(_channels * channels,int64_t id,int send,int recv)1363 _channel_drop(_channels *channels, int64_t id, int send, int recv)
1364 {
1365     PyInterpreterState *interp = _get_current();
1366     if (interp == NULL) {
1367         return -1;
1368     }
1369 
1370     // Look up the channel.
1371     PyThread_type_lock mutex = NULL;
1372     _PyChannelState *chan = _channels_lookup(channels, id, &mutex);
1373     if (chan == NULL) {
1374         return -1;
1375     }
1376     // Past this point we are responsible for releasing the mutex.
1377 
1378     // Close one or both of the two ends.
1379     int res = _channel_close_interpreter(chan, PyInterpreterState_GetID(interp), send-recv);
1380     PyThread_release_lock(mutex);
1381     return res;
1382 }
1383 
1384 static int
_channel_close(_channels * channels,int64_t id,int end,int force)1385 _channel_close(_channels *channels, int64_t id, int end, int force)
1386 {
1387     return _channels_close(channels, id, NULL, end, force);
1388 }
1389 
1390 static int
_channel_is_associated(_channels * channels,int64_t cid,int64_t interp,int send)1391 _channel_is_associated(_channels *channels, int64_t cid, int64_t interp,
1392                        int send)
1393 {
1394     _PyChannelState *chan = _channels_lookup(channels, cid, NULL);
1395     if (chan == NULL) {
1396         return -1;
1397     } else if (send && chan->closing != NULL) {
1398         PyErr_Format(ChannelClosedError, "channel %" PRId64 " closed", cid);
1399         return -1;
1400     }
1401 
1402     _channelend *end = _channelend_find(send ? chan->ends->send : chan->ends->recv,
1403                                         interp, NULL);
1404 
1405     return (end != NULL && end->open);
1406 }
1407 
1408 /* ChannelID class */
1409 
// Forward declaration; the type object is defined after its slots.
static PyTypeObject ChannelIDtype;

// Instance layout of ChannelID objects.
typedef struct channelid {
    PyObject_HEAD
    // The channel's numeric ID.
    int64_t id;
    // CHANNEL_SEND, CHANNEL_RECV, or 0 (reported as "both").
    int end;
    // Carried along when the ID is shared with another interpreter,
    // where it controls whether a high-level channel end object is
    // returned instead of a plain ChannelID.
    int resolve;
    // The registry in which this ID object is counted.
    _channels *channels;
} channelid;
1419 
1420 static int
channel_id_converter(PyObject * arg,void * ptr)1421 channel_id_converter(PyObject *arg, void *ptr)
1422 {
1423     int64_t cid;
1424     if (PyObject_TypeCheck(arg, &ChannelIDtype)) {
1425         cid = ((channelid *)arg)->id;
1426     }
1427     else if (PyIndex_Check(arg)) {
1428         cid = PyLong_AsLongLong(arg);
1429         if (cid == -1 && PyErr_Occurred()) {
1430             return 0;
1431         }
1432         if (cid < 0) {
1433             PyErr_Format(PyExc_ValueError,
1434                         "channel ID must be a non-negative int, got %R", arg);
1435             return 0;
1436         }
1437     }
1438     else {
1439         PyErr_Format(PyExc_TypeError,
1440                      "channel ID must be an int, got %.100s",
1441                      Py_TYPE(arg)->tp_name);
1442         return 0;
1443     }
1444     *(int64_t *)ptr = cid;
1445     return 1;
1446 }
1447 
1448 static channelid *
newchannelid(PyTypeObject * cls,int64_t cid,int end,_channels * channels,int force,int resolve)1449 newchannelid(PyTypeObject *cls, int64_t cid, int end, _channels *channels,
1450              int force, int resolve)
1451 {
1452     channelid *self = PyObject_New(channelid, cls);
1453     if (self == NULL) {
1454         return NULL;
1455     }
1456     self->id = cid;
1457     self->end = end;
1458     self->resolve = resolve;
1459     self->channels = channels;
1460 
1461     if (_channels_add_id_object(channels, cid) != 0) {
1462         if (force && PyErr_ExceptionMatches(ChannelNotFoundError)) {
1463             PyErr_Clear();
1464         }
1465         else {
1466             Py_DECREF((PyObject *)self);
1467             return NULL;
1468         }
1469     }
1470 
1471     return self;
1472 }
1473 
1474 static _channels * _global_channels(void);
1475 
1476 static PyObject *
channelid_new(PyTypeObject * cls,PyObject * args,PyObject * kwds)1477 channelid_new(PyTypeObject *cls, PyObject *args, PyObject *kwds)
1478 {
1479     static char *kwlist[] = {"id", "send", "recv", "force", "_resolve", NULL};
1480     int64_t cid;
1481     int send = -1;
1482     int recv = -1;
1483     int force = 0;
1484     int resolve = 0;
1485     if (!PyArg_ParseTupleAndKeywords(args, kwds,
1486                                      "O&|$pppp:ChannelID.__new__", kwlist,
1487                                      channel_id_converter, &cid, &send, &recv, &force, &resolve))
1488         return NULL;
1489 
1490     // Handle "send" and "recv".
1491     if (send == 0 && recv == 0) {
1492         PyErr_SetString(PyExc_ValueError,
1493                         "'send' and 'recv' cannot both be False");
1494         return NULL;
1495     }
1496 
1497     int end = 0;
1498     if (send == 1) {
1499         if (recv == 0 || recv == -1) {
1500             end = CHANNEL_SEND;
1501         }
1502     }
1503     else if (recv == 1) {
1504         end = CHANNEL_RECV;
1505     }
1506 
1507     return (PyObject *)newchannelid(cls, cid, end, _global_channels(),
1508                                     force, resolve);
1509 }
1510 
1511 static void
channelid_dealloc(PyObject * v)1512 channelid_dealloc(PyObject *v)
1513 {
1514     int64_t cid = ((channelid *)v)->id;
1515     _channels *channels = ((channelid *)v)->channels;
1516     Py_TYPE(v)->tp_free(v);
1517 
1518     _channels_drop_id_object(channels, cid);
1519 }
1520 
1521 static PyObject *
channelid_repr(PyObject * self)1522 channelid_repr(PyObject *self)
1523 {
1524     PyTypeObject *type = Py_TYPE(self);
1525     const char *name = _PyType_Name(type);
1526 
1527     channelid *cid = (channelid *)self;
1528     const char *fmt;
1529     if (cid->end == CHANNEL_SEND) {
1530         fmt = "%s(%" PRId64 ", send=True)";
1531     }
1532     else if (cid->end == CHANNEL_RECV) {
1533         fmt = "%s(%" PRId64 ", recv=True)";
1534     }
1535     else {
1536         fmt = "%s(%" PRId64 ")";
1537     }
1538     return PyUnicode_FromFormat(fmt, name, cid->id);
1539 }
1540 
1541 static PyObject *
channelid_str(PyObject * self)1542 channelid_str(PyObject *self)
1543 {
1544     channelid *cid = (channelid *)self;
1545     return PyUnicode_FromFormat("%" PRId64 "", cid->id);
1546 }
1547 
1548 static PyObject *
channelid_int(PyObject * self)1549 channelid_int(PyObject *self)
1550 {
1551     channelid *cid = (channelid *)self;
1552     return PyLong_FromLongLong(cid->id);
1553 }
1554 
1555 static PyNumberMethods channelid_as_number = {
1556      0,                        /* nb_add */
1557      0,                        /* nb_subtract */
1558      0,                        /* nb_multiply */
1559      0,                        /* nb_remainder */
1560      0,                        /* nb_divmod */
1561      0,                        /* nb_power */
1562      0,                        /* nb_negative */
1563      0,                        /* nb_positive */
1564      0,                        /* nb_absolute */
1565      0,                        /* nb_bool */
1566      0,                        /* nb_invert */
1567      0,                        /* nb_lshift */
1568      0,                        /* nb_rshift */
1569      0,                        /* nb_and */
1570      0,                        /* nb_xor */
1571      0,                        /* nb_or */
1572      (unaryfunc)channelid_int, /* nb_int */
1573      0,                        /* nb_reserved */
1574      0,                        /* nb_float */
1575 
1576      0,                        /* nb_inplace_add */
1577      0,                        /* nb_inplace_subtract */
1578      0,                        /* nb_inplace_multiply */
1579      0,                        /* nb_inplace_remainder */
1580      0,                        /* nb_inplace_power */
1581      0,                        /* nb_inplace_lshift */
1582      0,                        /* nb_inplace_rshift */
1583      0,                        /* nb_inplace_and */
1584      0,                        /* nb_inplace_xor */
1585      0,                        /* nb_inplace_or */
1586 
1587      0,                        /* nb_floor_divide */
1588      0,                        /* nb_true_divide */
1589      0,                        /* nb_inplace_floor_divide */
1590      0,                        /* nb_inplace_true_divide */
1591 
1592      (unaryfunc)channelid_int, /* nb_index */
1593 };
1594 
1595 static Py_hash_t
channelid_hash(PyObject * self)1596 channelid_hash(PyObject *self)
1597 {
1598     channelid *cid = (channelid *)self;
1599     PyObject *id = PyLong_FromLongLong(cid->id);
1600     if (id == NULL) {
1601         return -1;
1602     }
1603     Py_hash_t hash = PyObject_Hash(id);
1604     Py_DECREF(id);
1605     return hash;
1606 }
1607 
1608 static PyObject *
channelid_richcompare(PyObject * self,PyObject * other,int op)1609 channelid_richcompare(PyObject *self, PyObject *other, int op)
1610 {
1611     if (op != Py_EQ && op != Py_NE) {
1612         Py_RETURN_NOTIMPLEMENTED;
1613     }
1614 
1615     if (!PyObject_TypeCheck(self, &ChannelIDtype)) {
1616         Py_RETURN_NOTIMPLEMENTED;
1617     }
1618 
1619     channelid *cid = (channelid *)self;
1620     int equal;
1621     if (PyObject_TypeCheck(other, &ChannelIDtype)) {
1622         channelid *othercid = (channelid *)other;
1623         equal = (cid->end == othercid->end) && (cid->id == othercid->id);
1624     }
1625     else if (PyLong_Check(other)) {
1626         /* Fast path */
1627         int overflow;
1628         long long othercid = PyLong_AsLongLongAndOverflow(other, &overflow);
1629         if (othercid == -1 && PyErr_Occurred()) {
1630             return NULL;
1631         }
1632         equal = !overflow && (othercid >= 0) && (cid->id == othercid);
1633     }
1634     else if (PyNumber_Check(other)) {
1635         PyObject *pyid = PyLong_FromLongLong(cid->id);
1636         if (pyid == NULL) {
1637             return NULL;
1638         }
1639         PyObject *res = PyObject_RichCompare(pyid, other, op);
1640         Py_DECREF(pyid);
1641         return res;
1642     }
1643     else {
1644         Py_RETURN_NOTIMPLEMENTED;
1645     }
1646 
1647     if ((op == Py_EQ && equal) || (op == Py_NE && !equal)) {
1648         Py_RETURN_TRUE;
1649     }
1650     Py_RETURN_FALSE;
1651 }
1652 
1653 static PyObject *
_channel_from_cid(PyObject * cid,int end)1654 _channel_from_cid(PyObject *cid, int end)
1655 {
1656     PyObject *highlevel = PyImport_ImportModule("interpreters");
1657     if (highlevel == NULL) {
1658         PyErr_Clear();
1659         highlevel = PyImport_ImportModule("test.support.interpreters");
1660         if (highlevel == NULL) {
1661             return NULL;
1662         }
1663     }
1664     const char *clsname = (end == CHANNEL_RECV) ? "RecvChannel" :
1665                                                   "SendChannel";
1666     PyObject *cls = PyObject_GetAttrString(highlevel, clsname);
1667     Py_DECREF(highlevel);
1668     if (cls == NULL) {
1669         return NULL;
1670     }
1671     PyObject *chan = PyObject_CallFunctionObjArgs(cls, cid, NULL);
1672     Py_DECREF(cls);
1673     if (chan == NULL) {
1674         return NULL;
1675     }
1676     return chan;
1677 }
1678 
// Plain-data snapshot of a ChannelID used as its cross-interpreter
// representation; the fields mirror those of "channelid".
struct _channelid_xid {
    int64_t id;
    int end;
    int resolve;
};
1684 
1685 static PyObject *
_channelid_from_xid(_PyCrossInterpreterData * data)1686 _channelid_from_xid(_PyCrossInterpreterData *data)
1687 {
1688     struct _channelid_xid *xid = (struct _channelid_xid *)data->data;
1689     // Note that we do not preserve the "resolve" flag.
1690     PyObject *cid = (PyObject *)newchannelid(&ChannelIDtype, xid->id, xid->end,
1691                                              _global_channels(), 0, 0);
1692     if (xid->end == 0) {
1693         return cid;
1694     }
1695     if (!xid->resolve) {
1696         return cid;
1697     }
1698 
1699     /* Try returning a high-level channel end but fall back to the ID. */
1700     PyObject *chan = _channel_from_cid(cid, xid->end);
1701     if (chan == NULL) {
1702         PyErr_Clear();
1703         return cid;
1704     }
1705     Py_DECREF(cid);
1706     return chan;
1707 }
1708 
1709 static int
_channelid_shared(PyObject * obj,_PyCrossInterpreterData * data)1710 _channelid_shared(PyObject *obj, _PyCrossInterpreterData *data)
1711 {
1712     struct _channelid_xid *xid = PyMem_NEW(struct _channelid_xid, 1);
1713     if (xid == NULL) {
1714         return -1;
1715     }
1716     xid->id = ((channelid *)obj)->id;
1717     xid->end = ((channelid *)obj)->end;
1718     xid->resolve = ((channelid *)obj)->resolve;
1719 
1720     data->data = xid;
1721     Py_INCREF(obj);
1722     data->obj = obj;
1723     data->new_object = _channelid_from_xid;
1724     data->free = PyMem_Free;
1725     return 0;
1726 }
1727 
1728 static PyObject *
channelid_end(PyObject * self,void * end)1729 channelid_end(PyObject *self, void *end)
1730 {
1731     int force = 1;
1732     channelid *cid = (channelid *)self;
1733     if (end != NULL) {
1734         return (PyObject *)newchannelid(Py_TYPE(self), cid->id, *(int *)end,
1735                                         cid->channels, force, cid->resolve);
1736     }
1737 
1738     if (cid->end == CHANNEL_SEND) {
1739         return PyUnicode_InternFromString("send");
1740     }
1741     if (cid->end == CHANNEL_RECV) {
1742         return PyUnicode_InternFromString("recv");
1743     }
1744     return PyUnicode_InternFromString("both");
1745 }
1746 
// Closure values for the "send" and "recv" getters below; channelid_end()
// reads the requested end through the closure pointer.
static int _channelid_end_send = CHANNEL_SEND;
static int _channelid_end_recv = CHANNEL_RECV;

// Attribute table for ChannelID.  All three entries share one getter:
// "end" has no closure and so reports the end as a string, while "send"
// and "recv" pass a closure and so produce a ChannelID for that end.
static PyGetSetDef channelid_getsets[] = {
    {"end", (getter)channelid_end, NULL,
     PyDoc_STR("'send', 'recv', or 'both'")},
    {"send", (getter)channelid_end, NULL,
     PyDoc_STR("the 'send' end of the channel"), &_channelid_end_send},
    {"recv", (getter)channelid_end, NULL,
     PyDoc_STR("the 'recv' end of the channel"), &_channelid_end_recv},
    {NULL}
};
1759 
PyDoc_STRVAR(channelid_doc,
"A channel ID identifies a channel and may be used as an int.");

// Type object for ChannelID.  Slots are given positionally, so the order
// of the entries below must follow the PyTypeObject field order.
static PyTypeObject ChannelIDtype = {
    PyVarObject_HEAD_INIT(&PyType_Type, 0)
    "_xxsubinterpreters.ChannelID", /* tp_name */
    sizeof(channelid),              /* tp_basicsize */
    0,                              /* tp_itemsize */
    (destructor)channelid_dealloc,  /* tp_dealloc */
    0,                              /* tp_vectorcall_offset */
    0,                              /* tp_getattr */
    0,                              /* tp_setattr */
    0,                              /* tp_as_async */
    (reprfunc)channelid_repr,       /* tp_repr */
    &channelid_as_number,           /* tp_as_number */
    0,                              /* tp_as_sequence */
    0,                              /* tp_as_mapping */
    channelid_hash,                 /* tp_hash */
    0,                              /* tp_call */
    (reprfunc)channelid_str,        /* tp_str */
    0,                              /* tp_getattro */
    0,                              /* tp_setattro */
    0,                              /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
    channelid_doc,                  /* tp_doc */
    0,                              /* tp_traverse */
    0,                              /* tp_clear */
    channelid_richcompare,          /* tp_richcompare */
    0,                              /* tp_weaklistoffset */
    0,                              /* tp_iter */
    0,                              /* tp_iternext */
    0,                              /* tp_methods */
    0,                              /* tp_members */
    channelid_getsets,              /* tp_getset */
    0,                              /* tp_base */
    0,                              /* tp_dict */
    0,                              /* tp_descr_get */
    0,                              /* tp_descr_set */
    0,                              /* tp_dictoffset */
    0,                              /* tp_init */
    0,                              /* tp_alloc */
    // Note that we do not set tp_new to channelid_new.  Instead we
    // set it to NULL, meaning it cannot be instantiated from Python
    // code.  We do this because there is a strong relationship between
    // channel IDs and the channel lifecycle, so this limitation avoids
    // related complications.
    NULL,                           /* tp_new */
};
1808 
1809 
1810 /* interpreter-specific code ************************************************/
1811 
1812 static PyObject * RunFailedError = NULL;
1813 
1814 static int
interp_exceptions_init(PyObject * ns)1815 interp_exceptions_init(PyObject *ns)
1816 {
1817     // XXX Move the exceptions into per-module memory?
1818 
1819     if (RunFailedError == NULL) {
1820         // An uncaught exception came out of interp_run_string().
1821         RunFailedError = PyErr_NewException("_xxsubinterpreters.RunFailedError",
1822                                             PyExc_RuntimeError, NULL);
1823         if (RunFailedError == NULL) {
1824             return -1;
1825         }
1826         if (PyDict_SetItemString(ns, "RunFailedError", RunFailedError) != 0) {
1827             return -1;
1828         }
1829     }
1830 
1831     return 0;
1832 }
1833 
1834 static int
_is_running(PyInterpreterState * interp)1835 _is_running(PyInterpreterState *interp)
1836 {
1837     PyThreadState *tstate = PyInterpreterState_ThreadHead(interp);
1838     if (PyThreadState_Next(tstate) != NULL) {
1839         PyErr_SetString(PyExc_RuntimeError,
1840                         "interpreter has more than one thread");
1841         return -1;
1842     }
1843 
1844     assert(!PyErr_Occurred());
1845     PyFrameObject *frame = PyThreadState_GetFrame(tstate);
1846     if (frame == NULL) {
1847         return 0;
1848     }
1849 
1850     int executing = (int)(frame->f_executing);
1851     Py_DECREF(frame);
1852 
1853     return executing;
1854 }
1855 
1856 static int
_ensure_not_running(PyInterpreterState * interp)1857 _ensure_not_running(PyInterpreterState *interp)
1858 {
1859     int is_running = _is_running(interp);
1860     if (is_running < 0) {
1861         return -1;
1862     }
1863     if (is_running) {
1864         PyErr_Format(PyExc_RuntimeError, "interpreter already running");
1865         return -1;
1866     }
1867     return 0;
1868 }
1869 
1870 static int
_run_script(PyInterpreterState * interp,const char * codestr,_sharedns * shared,_sharedexception ** exc)1871 _run_script(PyInterpreterState *interp, const char *codestr,
1872             _sharedns *shared, _sharedexception **exc)
1873 {
1874     PyObject *exctype = NULL;
1875     PyObject *excval = NULL;
1876     PyObject *tb = NULL;
1877 
1878     PyObject *main_mod = _PyInterpreterState_GetMainModule(interp);
1879     if (main_mod == NULL) {
1880         goto error;
1881     }
1882     PyObject *ns = PyModule_GetDict(main_mod);  // borrowed
1883     Py_DECREF(main_mod);
1884     if (ns == NULL) {
1885         goto error;
1886     }
1887     Py_INCREF(ns);
1888 
1889     // Apply the cross-interpreter data.
1890     if (shared != NULL) {
1891         if (_sharedns_apply(shared, ns) != 0) {
1892             Py_DECREF(ns);
1893             goto error;
1894         }
1895     }
1896 
1897     // Run the string (see PyRun_SimpleStringFlags).
1898     PyObject *result = PyRun_StringFlags(codestr, Py_file_input, ns, ns, NULL);
1899     Py_DECREF(ns);
1900     if (result == NULL) {
1901         goto error;
1902     }
1903     else {
1904         Py_DECREF(result);  // We throw away the result.
1905     }
1906 
1907     *exc = NULL;
1908     return 0;
1909 
1910 error:
1911     PyErr_Fetch(&exctype, &excval, &tb);
1912 
1913     _sharedexception *sharedexc = _sharedexception_bind(exctype, excval, tb);
1914     Py_XDECREF(exctype);
1915     Py_XDECREF(excval);
1916     Py_XDECREF(tb);
1917     if (sharedexc == NULL) {
1918         fprintf(stderr, "RunFailedError: script raised an uncaught exception");
1919         PyErr_Clear();
1920         sharedexc = NULL;
1921     }
1922     else {
1923         assert(!PyErr_Occurred());
1924     }
1925     *exc = sharedexc;
1926     return -1;
1927 }
1928 
1929 static int
_run_script_in_interpreter(PyInterpreterState * interp,const char * codestr,PyObject * shareables)1930 _run_script_in_interpreter(PyInterpreterState *interp, const char *codestr,
1931                            PyObject *shareables)
1932 {
1933     if (_ensure_not_running(interp) < 0) {
1934         return -1;
1935     }
1936 
1937     _sharedns *shared = _get_shared_ns(shareables);
1938     if (shared == NULL && PyErr_Occurred()) {
1939         return -1;
1940     }
1941 
1942     // Switch to interpreter.
1943     PyThreadState *save_tstate = NULL;
1944     if (interp != PyInterpreterState_Get()) {
1945         // XXX Using the "head" thread isn't strictly correct.
1946         PyThreadState *tstate = PyInterpreterState_ThreadHead(interp);
1947         // XXX Possible GILState issues?
1948         save_tstate = PyThreadState_Swap(tstate);
1949     }
1950 
1951     // Run the script.
1952     _sharedexception *exc = NULL;
1953     int result = _run_script(interp, codestr, shared, &exc);
1954 
1955     // Switch back.
1956     if (save_tstate != NULL) {
1957         PyThreadState_Swap(save_tstate);
1958     }
1959 
1960     // Propagate any exception out to the caller.
1961     if (exc != NULL) {
1962         _sharedexception_apply(exc, RunFailedError);
1963         _sharedexception_free(exc);
1964     }
1965     else if (result != 0) {
1966         // We were unable to allocate a shared exception.
1967         PyErr_NoMemory();
1968     }
1969 
1970     if (shared != NULL) {
1971         _sharedns_free(shared);
1972     }
1973 
1974     return result;
1975 }
1976 
1977 
1978 /* module level code ********************************************************/
1979 
/* _globals is the process-global state for the module.  It holds all
   the data that we need to share between interpreters, so it cannot
   hold PyObject values.  At present that is just the channel registry,
   which starts out zero-initialized and is set up by _init_globals(). */
static struct globals {
    _channels channels;
} _globals = {{0}};
1986 
1987 static int
_init_globals(void)1988 _init_globals(void)
1989 {
1990     if (_channels_init(&_globals.channels) != 0) {
1991         return -1;
1992     }
1993     return 0;
1994 }
1995 
1996 static _channels *
_global_channels(void)1997 _global_channels(void) {
1998     return &_globals.channels;
1999 }
2000 
2001 static PyObject *
interp_create(PyObject * self,PyObject * args,PyObject * kwds)2002 interp_create(PyObject *self, PyObject *args, PyObject *kwds)
2003 {
2004 
2005     static char *kwlist[] = {"isolated", NULL};
2006     int isolated = 1;
2007     if (!PyArg_ParseTupleAndKeywords(args, kwds, "|$i:create", kwlist,
2008                                      &isolated)) {
2009         return NULL;
2010     }
2011 
2012     // Create and initialize the new interpreter.
2013     PyThreadState *save_tstate = PyThreadState_Get();
2014     // XXX Possible GILState issues?
2015     PyThreadState *tstate = _Py_NewInterpreter(isolated);
2016     PyThreadState_Swap(save_tstate);
2017     if (tstate == NULL) {
2018         /* Since no new thread state was created, there is no exception to
2019            propagate; raise a fresh one after swapping in the old thread
2020            state. */
2021         PyErr_SetString(PyExc_RuntimeError, "interpreter creation failed");
2022         return NULL;
2023     }
2024     PyInterpreterState *interp = PyThreadState_GetInterpreter(tstate);
2025     PyObject *idobj = _PyInterpreterState_GetIDObject(interp);
2026     if (idobj == NULL) {
2027         // XXX Possible GILState issues?
2028         save_tstate = PyThreadState_Swap(tstate);
2029         Py_EndInterpreter(tstate);
2030         PyThreadState_Swap(save_tstate);
2031         return NULL;
2032     }
2033     _PyInterpreterState_RequireIDRef(interp, 1);
2034     return idobj;
2035 }
2036 
2037 PyDoc_STRVAR(create_doc,
2038 "create() -> ID\n\
2039 \n\
2040 Create a new interpreter and return a unique generated ID.");
2041 
2042 
2043 static PyObject *
interp_destroy(PyObject * self,PyObject * args,PyObject * kwds)2044 interp_destroy(PyObject *self, PyObject *args, PyObject *kwds)
2045 {
2046     static char *kwlist[] = {"id", NULL};
2047     PyObject *id;
2048     // XXX Use "L" for id?
2049     if (!PyArg_ParseTupleAndKeywords(args, kwds,
2050                                      "O:destroy", kwlist, &id)) {
2051         return NULL;
2052     }
2053 
2054     // Look up the interpreter.
2055     PyInterpreterState *interp = _PyInterpreterID_LookUp(id);
2056     if (interp == NULL) {
2057         return NULL;
2058     }
2059 
2060     // Ensure we don't try to destroy the current interpreter.
2061     PyInterpreterState *current = _get_current();
2062     if (current == NULL) {
2063         return NULL;
2064     }
2065     if (interp == current) {
2066         PyErr_SetString(PyExc_RuntimeError,
2067                         "cannot destroy the current interpreter");
2068         return NULL;
2069     }
2070 
2071     // Ensure the interpreter isn't running.
2072     /* XXX We *could* support destroying a running interpreter but
2073        aren't going to worry about it for now. */
2074     if (_ensure_not_running(interp) < 0) {
2075         return NULL;
2076     }
2077 
2078     // Destroy the interpreter.
2079     PyThreadState *tstate = PyInterpreterState_ThreadHead(interp);
2080     // XXX Possible GILState issues?
2081     PyThreadState *save_tstate = PyThreadState_Swap(tstate);
2082     Py_EndInterpreter(tstate);
2083     PyThreadState_Swap(save_tstate);
2084 
2085     Py_RETURN_NONE;
2086 }
2087 
2088 PyDoc_STRVAR(destroy_doc,
2089 "destroy(id)\n\
2090 \n\
2091 Destroy the identified interpreter.\n\
2092 \n\
2093 Attempting to destroy the current interpreter results in a RuntimeError.\n\
2094 So does an unrecognized ID.");
2095 
2096 
2097 static PyObject *
interp_list_all(PyObject * self,PyObject * Py_UNUSED (ignored))2098 interp_list_all(PyObject *self, PyObject *Py_UNUSED(ignored))
2099 {
2100     PyObject *ids, *id;
2101     PyInterpreterState *interp;
2102 
2103     ids = PyList_New(0);
2104     if (ids == NULL) {
2105         return NULL;
2106     }
2107 
2108     interp = PyInterpreterState_Head();
2109     while (interp != NULL) {
2110         id = _PyInterpreterState_GetIDObject(interp);
2111         if (id == NULL) {
2112             Py_DECREF(ids);
2113             return NULL;
2114         }
2115         // insert at front of list
2116         int res = PyList_Insert(ids, 0, id);
2117         Py_DECREF(id);
2118         if (res < 0) {
2119             Py_DECREF(ids);
2120             return NULL;
2121         }
2122 
2123         interp = PyInterpreterState_Next(interp);
2124     }
2125 
2126     return ids;
2127 }
2128 
2129 PyDoc_STRVAR(list_all_doc,
2130 "list_all() -> [ID]\n\
2131 \n\
2132 Return a list containing the ID of every existing interpreter.");
2133 
2134 
2135 static PyObject *
interp_get_current(PyObject * self,PyObject * Py_UNUSED (ignored))2136 interp_get_current(PyObject *self, PyObject *Py_UNUSED(ignored))
2137 {
2138     PyInterpreterState *interp =_get_current();
2139     if (interp == NULL) {
2140         return NULL;
2141     }
2142     return _PyInterpreterState_GetIDObject(interp);
2143 }
2144 
2145 PyDoc_STRVAR(get_current_doc,
2146 "get_current() -> ID\n\
2147 \n\
2148 Return the ID of current interpreter.");
2149 
2150 
2151 static PyObject *
interp_get_main(PyObject * self,PyObject * Py_UNUSED (ignored))2152 interp_get_main(PyObject *self, PyObject *Py_UNUSED(ignored))
2153 {
2154     // Currently, 0 is always the main interpreter.
2155     int64_t id = 0;
2156     return _PyInterpreterID_New(id);
2157 }
2158 
2159 PyDoc_STRVAR(get_main_doc,
2160 "get_main() -> ID\n\
2161 \n\
2162 Return the ID of main interpreter.");
2163 
2164 
2165 static PyObject *
interp_run_string(PyObject * self,PyObject * args,PyObject * kwds)2166 interp_run_string(PyObject *self, PyObject *args, PyObject *kwds)
2167 {
2168     static char *kwlist[] = {"id", "script", "shared", NULL};
2169     PyObject *id, *code;
2170     PyObject *shared = NULL;
2171     if (!PyArg_ParseTupleAndKeywords(args, kwds,
2172                                      "OU|O:run_string", kwlist,
2173                                      &id, &code, &shared)) {
2174         return NULL;
2175     }
2176 
2177     // Look up the interpreter.
2178     PyInterpreterState *interp = _PyInterpreterID_LookUp(id);
2179     if (interp == NULL) {
2180         return NULL;
2181     }
2182 
2183     // Extract code.
2184     Py_ssize_t size;
2185     const char *codestr = PyUnicode_AsUTF8AndSize(code, &size);
2186     if (codestr == NULL) {
2187         return NULL;
2188     }
2189     if (strlen(codestr) != (size_t)size) {
2190         PyErr_SetString(PyExc_ValueError,
2191                         "source code string cannot contain null bytes");
2192         return NULL;
2193     }
2194 
2195     // Run the code in the interpreter.
2196     if (_run_script_in_interpreter(interp, codestr, shared) != 0) {
2197         return NULL;
2198     }
2199     Py_RETURN_NONE;
2200 }
2201 
2202 PyDoc_STRVAR(run_string_doc,
2203 "run_string(id, script, shared)\n\
2204 \n\
2205 Execute the provided string in the identified interpreter.\n\
2206 \n\
2207 See PyRun_SimpleStrings.");
2208 
2209 
2210 static PyObject *
object_is_shareable(PyObject * self,PyObject * args,PyObject * kwds)2211 object_is_shareable(PyObject *self, PyObject *args, PyObject *kwds)
2212 {
2213     static char *kwlist[] = {"obj", NULL};
2214     PyObject *obj;
2215     if (!PyArg_ParseTupleAndKeywords(args, kwds,
2216                                      "O:is_shareable", kwlist, &obj)) {
2217         return NULL;
2218     }
2219 
2220     if (_PyObject_CheckCrossInterpreterData(obj) == 0) {
2221         Py_RETURN_TRUE;
2222     }
2223     PyErr_Clear();
2224     Py_RETURN_FALSE;
2225 }
2226 
2227 PyDoc_STRVAR(is_shareable_doc,
2228 "is_shareable(obj) -> bool\n\
2229 \n\
2230 Return True if the object's data may be shared between interpreters and\n\
2231 False otherwise.");
2232 
2233 
2234 static PyObject *
interp_is_running(PyObject * self,PyObject * args,PyObject * kwds)2235 interp_is_running(PyObject *self, PyObject *args, PyObject *kwds)
2236 {
2237     static char *kwlist[] = {"id", NULL};
2238     PyObject *id;
2239     if (!PyArg_ParseTupleAndKeywords(args, kwds,
2240                                      "O:is_running", kwlist, &id)) {
2241         return NULL;
2242     }
2243 
2244     PyInterpreterState *interp = _PyInterpreterID_LookUp(id);
2245     if (interp == NULL) {
2246         return NULL;
2247     }
2248     int is_running = _is_running(interp);
2249     if (is_running < 0) {
2250         return NULL;
2251     }
2252     if (is_running) {
2253         Py_RETURN_TRUE;
2254     }
2255     Py_RETURN_FALSE;
2256 }
2257 
2258 PyDoc_STRVAR(is_running_doc,
2259 "is_running(id) -> bool\n\
2260 \n\
2261 Return whether or not the identified interpreter is running.");
2262 
2263 static PyObject *
channel_create(PyObject * self,PyObject * Py_UNUSED (ignored))2264 channel_create(PyObject *self, PyObject *Py_UNUSED(ignored))
2265 {
2266     int64_t cid = _channel_create(&_globals.channels);
2267     if (cid < 0) {
2268         return NULL;
2269     }
2270     PyObject *id = (PyObject *)newchannelid(&ChannelIDtype, cid, 0,
2271                                             &_globals.channels, 0, 0);
2272     if (id == NULL) {
2273         if (_channel_destroy(&_globals.channels, cid) != 0) {
2274             // XXX issue a warning?
2275         }
2276         return NULL;
2277     }
2278     assert(((channelid *)id)->channels != NULL);
2279     return id;
2280 }
2281 
2282 PyDoc_STRVAR(channel_create_doc,
2283 "channel_create() -> cid\n\
2284 \n\
2285 Create a new cross-interpreter channel and return a unique generated ID.");
2286 
2287 static PyObject *
channel_destroy(PyObject * self,PyObject * args,PyObject * kwds)2288 channel_destroy(PyObject *self, PyObject *args, PyObject *kwds)
2289 {
2290     static char *kwlist[] = {"cid", NULL};
2291     int64_t cid;
2292     if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&:channel_destroy", kwlist,
2293                                      channel_id_converter, &cid)) {
2294         return NULL;
2295     }
2296 
2297     if (_channel_destroy(&_globals.channels, cid) != 0) {
2298         return NULL;
2299     }
2300     Py_RETURN_NONE;
2301 }
2302 
2303 PyDoc_STRVAR(channel_destroy_doc,
2304 "channel_destroy(cid)\n\
2305 \n\
2306 Close and finalize the channel.  Afterward attempts to use the channel\n\
2307 will behave as though it never existed.");
2308 
2309 static PyObject *
channel_list_all(PyObject * self,PyObject * Py_UNUSED (ignored))2310 channel_list_all(PyObject *self, PyObject *Py_UNUSED(ignored))
2311 {
2312     int64_t count = 0;
2313     int64_t *cids = _channels_list_all(&_globals.channels, &count);
2314     if (cids == NULL) {
2315         if (count == 0) {
2316             return PyList_New(0);
2317         }
2318         return NULL;
2319     }
2320     PyObject *ids = PyList_New((Py_ssize_t)count);
2321     if (ids == NULL) {
2322         goto finally;
2323     }
2324     int64_t *cur = cids;
2325     for (int64_t i=0; i < count; cur++, i++) {
2326         PyObject *id = (PyObject *)newchannelid(&ChannelIDtype, *cur, 0,
2327                                                 &_globals.channels, 0, 0);
2328         if (id == NULL) {
2329             Py_DECREF(ids);
2330             ids = NULL;
2331             break;
2332         }
2333         PyList_SET_ITEM(ids, i, id);
2334     }
2335 
2336 finally:
2337     PyMem_Free(cids);
2338     return ids;
2339 }
2340 
2341 PyDoc_STRVAR(channel_list_all_doc,
2342 "channel_list_all() -> [cid]\n\
2343 \n\
2344 Return the list of all IDs for active channels.");
2345 
2346 static PyObject *
channel_list_interpreters(PyObject * self,PyObject * args,PyObject * kwds)2347 channel_list_interpreters(PyObject *self, PyObject *args, PyObject *kwds)
2348 {
2349     static char *kwlist[] = {"cid", "send", NULL};
2350     int64_t cid;            /* Channel ID */
2351     int send = 0;           /* Send or receive end? */
2352     int64_t id;
2353     PyObject *ids, *id_obj;
2354     PyInterpreterState *interp;
2355 
2356     if (!PyArg_ParseTupleAndKeywords(
2357             args, kwds, "O&$p:channel_list_interpreters",
2358             kwlist, channel_id_converter, &cid, &send)) {
2359         return NULL;
2360     }
2361 
2362     ids = PyList_New(0);
2363     if (ids == NULL) {
2364         goto except;
2365     }
2366 
2367     interp = PyInterpreterState_Head();
2368     while (interp != NULL) {
2369         id = PyInterpreterState_GetID(interp);
2370         assert(id >= 0);
2371         int res = _channel_is_associated(&_globals.channels, cid, id, send);
2372         if (res < 0) {
2373             goto except;
2374         }
2375         if (res) {
2376             id_obj = _PyInterpreterState_GetIDObject(interp);
2377             if (id_obj == NULL) {
2378                 goto except;
2379             }
2380             res = PyList_Insert(ids, 0, id_obj);
2381             Py_DECREF(id_obj);
2382             if (res < 0) {
2383                 goto except;
2384             }
2385         }
2386         interp = PyInterpreterState_Next(interp);
2387     }
2388 
2389     goto finally;
2390 
2391 except:
2392     Py_XDECREF(ids);
2393     ids = NULL;
2394 
2395 finally:
2396     return ids;
2397 }
2398 
2399 PyDoc_STRVAR(channel_list_interpreters_doc,
2400 "channel_list_interpreters(cid, *, send) -> [id]\n\
2401 \n\
2402 Return the list of all interpreter IDs associated with an end of the channel.\n\
2403 \n\
2404 The 'send' argument should be a boolean indicating whether to use the send or\n\
2405 receive end.");
2406 
2407 
2408 static PyObject *
channel_send(PyObject * self,PyObject * args,PyObject * kwds)2409 channel_send(PyObject *self, PyObject *args, PyObject *kwds)
2410 {
2411     static char *kwlist[] = {"cid", "obj", NULL};
2412     int64_t cid;
2413     PyObject *obj;
2414     if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&O:channel_send", kwlist,
2415                                      channel_id_converter, &cid, &obj)) {
2416         return NULL;
2417     }
2418 
2419     if (_channel_send(&_globals.channels, cid, obj) != 0) {
2420         return NULL;
2421     }
2422     Py_RETURN_NONE;
2423 }
2424 
2425 PyDoc_STRVAR(channel_send_doc,
2426 "channel_send(cid, obj)\n\
2427 \n\
2428 Add the object's data to the channel's queue.");
2429 
2430 static PyObject *
channel_recv(PyObject * self,PyObject * args,PyObject * kwds)2431 channel_recv(PyObject *self, PyObject *args, PyObject *kwds)
2432 {
2433     static char *kwlist[] = {"cid", "default", NULL};
2434     int64_t cid;
2435     PyObject *dflt = NULL;
2436     if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&|O:channel_recv", kwlist,
2437                                      channel_id_converter, &cid, &dflt)) {
2438         return NULL;
2439     }
2440     Py_XINCREF(dflt);
2441 
2442     PyObject *obj = _channel_recv(&_globals.channels, cid);
2443     if (obj != NULL) {
2444         Py_XDECREF(dflt);
2445         return obj;
2446     } else if (PyErr_Occurred()) {
2447         Py_XDECREF(dflt);
2448         return NULL;
2449     } else if (dflt != NULL) {
2450         return dflt;
2451     } else {
2452         PyErr_Format(ChannelEmptyError, "channel %" PRId64 " is empty", cid);
2453         return NULL;
2454     }
2455 }
2456 
2457 PyDoc_STRVAR(channel_recv_doc,
2458 "channel_recv(cid, [default]) -> obj\n\
2459 \n\
2460 Return a new object from the data at the front of the channel's queue.\n\
2461 \n\
2462 If there is nothing to receive then raise ChannelEmptyError, unless\n\
2463 a default value is provided.  In that case return it.");
2464 
2465 static PyObject *
channel_close(PyObject * self,PyObject * args,PyObject * kwds)2466 channel_close(PyObject *self, PyObject *args, PyObject *kwds)
2467 {
2468     static char *kwlist[] = {"cid", "send", "recv", "force", NULL};
2469     int64_t cid;
2470     int send = 0;
2471     int recv = 0;
2472     int force = 0;
2473     if (!PyArg_ParseTupleAndKeywords(args, kwds,
2474                                      "O&|$ppp:channel_close", kwlist,
2475                                      channel_id_converter, &cid, &send, &recv, &force)) {
2476         return NULL;
2477     }
2478 
2479     if (_channel_close(&_globals.channels, cid, send-recv, force) != 0) {
2480         return NULL;
2481     }
2482     Py_RETURN_NONE;
2483 }
2484 
2485 PyDoc_STRVAR(channel_close_doc,
2486 "channel_close(cid, *, send=None, recv=None, force=False)\n\
2487 \n\
2488 Close the channel for all interpreters.\n\
2489 \n\
2490 If the channel is empty then the keyword args are ignored and both\n\
2491 ends are immediately closed.  Otherwise, if 'force' is True then\n\
2492 all queued items are released and both ends are immediately\n\
2493 closed.\n\
2494 \n\
2495 If the channel is not empty *and* 'force' is False then following\n\
2496 happens:\n\
2497 \n\
2498  * recv is True (regardless of send):\n\
2499    - raise ChannelNotEmptyError\n\
2500  * recv is None and send is None:\n\
2501    - raise ChannelNotEmptyError\n\
2502  * send is True and recv is not True:\n\
2503    - fully close the 'send' end\n\
2504    - close the 'recv' end to interpreters not already receiving\n\
2505    - fully close it once empty\n\
2506 \n\
2507 Closing an already closed channel results in a ChannelClosedError.\n\
2508 \n\
2509 Once the channel's ID has no more ref counts in any interpreter\n\
2510 the channel will be destroyed.");
2511 
2512 static PyObject *
channel_release(PyObject * self,PyObject * args,PyObject * kwds)2513 channel_release(PyObject *self, PyObject *args, PyObject *kwds)
2514 {
2515     // Note that only the current interpreter is affected.
2516     static char *kwlist[] = {"cid", "send", "recv", "force", NULL};
2517     int64_t cid;
2518     int send = 0;
2519     int recv = 0;
2520     int force = 0;
2521     if (!PyArg_ParseTupleAndKeywords(args, kwds,
2522                                      "O&|$ppp:channel_release", kwlist,
2523                                      channel_id_converter, &cid, &send, &recv, &force)) {
2524         return NULL;
2525     }
2526     if (send == 0 && recv == 0) {
2527         send = 1;
2528         recv = 1;
2529     }
2530 
2531     // XXX Handle force is True.
2532     // XXX Fix implicit release.
2533 
2534     if (_channel_drop(&_globals.channels, cid, send, recv) != 0) {
2535         return NULL;
2536     }
2537     Py_RETURN_NONE;
2538 }
2539 
2540 PyDoc_STRVAR(channel_release_doc,
2541 "channel_release(cid, *, send=None, recv=None, force=True)\n\
2542 \n\
2543 Close the channel for the current interpreter.  'send' and 'recv'\n\
2544 (bool) may be used to indicate the ends to close.  By default both\n\
2545 ends are closed.  Closing an already closed end is a noop.");
2546 
2547 static PyObject *
channel__channel_id(PyObject * self,PyObject * args,PyObject * kwds)2548 channel__channel_id(PyObject *self, PyObject *args, PyObject *kwds)
2549 {
2550     return channelid_new(&ChannelIDtype, args, kwds);
2551 }
2552 
/* Method table for the module: maps each Python-level function name to
   its C implementation, calling convention flags, and docstring.

   Functions that accept keyword arguments are registered with
   METH_VARARGS | METH_KEYWORDS and go through an intermediate
   (void(*)(void)) cast before the cast to PyCFunction.
   NOTE(review): presumably this is to avoid function-type-cast warnings
   for the three-argument signature -- confirm. */
static PyMethodDef module_functions[] = {
    /* interpreter management */
    {"create",                    (PyCFunction)(void(*)(void))interp_create,
     METH_VARARGS | METH_KEYWORDS, create_doc},
    {"destroy",                   (PyCFunction)(void(*)(void))interp_destroy,
     METH_VARARGS | METH_KEYWORDS, destroy_doc},
    {"list_all",                  interp_list_all,
     METH_NOARGS, list_all_doc},
    {"get_current",               interp_get_current,
     METH_NOARGS, get_current_doc},
    {"get_main",                  interp_get_main,
     METH_NOARGS, get_main_doc},
    {"is_running",                (PyCFunction)(void(*)(void))interp_is_running,
     METH_VARARGS | METH_KEYWORDS, is_running_doc},
    {"run_string",                (PyCFunction)(void(*)(void))interp_run_string,
     METH_VARARGS | METH_KEYWORDS, run_string_doc},

    /* cross-interpreter data */
    {"is_shareable",              (PyCFunction)(void(*)(void))object_is_shareable,
     METH_VARARGS | METH_KEYWORDS, is_shareable_doc},

    /* channels */
    {"channel_create",            channel_create,
     METH_NOARGS, channel_create_doc},
    {"channel_destroy",           (PyCFunction)(void(*)(void))channel_destroy,
     METH_VARARGS | METH_KEYWORDS, channel_destroy_doc},
    {"channel_list_all",          channel_list_all,
     METH_NOARGS, channel_list_all_doc},
    {"channel_list_interpreters", (PyCFunction)(void(*)(void))channel_list_interpreters,
     METH_VARARGS | METH_KEYWORDS, channel_list_interpreters_doc},
    {"channel_send",              (PyCFunction)(void(*)(void))channel_send,
     METH_VARARGS | METH_KEYWORDS, channel_send_doc},
    {"channel_recv",              (PyCFunction)(void(*)(void))channel_recv,
     METH_VARARGS | METH_KEYWORDS, channel_recv_doc},
    {"channel_close",             (PyCFunction)(void(*)(void))channel_close,
     METH_VARARGS | METH_KEYWORDS, channel_close_doc},
    {"channel_release",           (PyCFunction)(void(*)(void))channel_release,
     METH_VARARGS | METH_KEYWORDS, channel_release_doc},
    /* registered without a docstring */
    {"_channel_id",               (PyCFunction)(void(*)(void))channel__channel_id,
     METH_VARARGS | METH_KEYWORDS, NULL},

    {NULL,                        NULL}           /* sentinel */
};
2593 
2594 
2595 /* initialization function */
2596 
/* Docstring of the module itself. */
PyDoc_STRVAR(module_doc,
"This module provides primitive operations to manage Python interpreters.\n\
The 'interpreters' module provides a more convenient interface.");

/* Module definition passed to PyModule_Create() by the init function.
   Only the name, docstring and method table are filled in; the slots
   and the traverse/clear/free hooks are left NULL.
   NOTE(review): m_size is -1, which presumably reflects that the
   module's state lives in the static _globals rather than in per-module
   storage -- confirm against the PyModuleDef documentation. */
static struct PyModuleDef interpretersmodule = {
    PyModuleDef_HEAD_INIT,
    "_xxsubinterpreters",  /* m_name */
    module_doc,            /* m_doc */
    -1,                    /* m_size */
    module_functions,      /* m_methods */
    NULL,                  /* m_slots */
    NULL,                  /* m_traverse */
    NULL,                  /* m_clear */
    NULL                   /* m_free */
};
2612 
2613 
2614 PyMODINIT_FUNC
PyInit__xxsubinterpreters(void)2615 PyInit__xxsubinterpreters(void)
2616 {
2617     if (_init_globals() != 0) {
2618         return NULL;
2619     }
2620 
2621     /* Initialize types */
2622     if (PyType_Ready(&ChannelIDtype) != 0) {
2623         return NULL;
2624     }
2625 
2626     /* Create the module */
2627     PyObject *module = PyModule_Create(&interpretersmodule);
2628     if (module == NULL) {
2629         return NULL;
2630     }
2631 
2632     /* Add exception types */
2633     PyObject *ns = PyModule_GetDict(module);  // borrowed
2634     if (interp_exceptions_init(ns) != 0) {
2635         return NULL;
2636     }
2637     if (channel_exceptions_init(ns) != 0) {
2638         return NULL;
2639     }
2640 
2641     /* Add other types */
2642     Py_INCREF(&ChannelIDtype);
2643     if (PyDict_SetItemString(ns, "ChannelID", (PyObject *)&ChannelIDtype) != 0) {
2644         return NULL;
2645     }
2646     Py_INCREF(&_PyInterpreterID_Type);
2647     if (PyDict_SetItemString(ns, "InterpreterID", (PyObject *)&_PyInterpreterID_Type) != 0) {
2648         return NULL;
2649     }
2650 
2651     if (_PyCrossInterpreterData_RegisterClass(&ChannelIDtype, _channelid_shared)) {
2652         return NULL;
2653     }
2654 
2655     return module;
2656 }
2657