1;;; Hy AST walker
2;; Copyright 2020 the authors.
3;; This file is part of Hy, which is free software licensed under the Expat
4;; license. See the LICENSE.
5
6(import [hy [HyExpression HyDict]]
7        [hy.models [HySequence]]
8        [functools [partial]]
9        [importlib [import-module]]
10        [collections [OrderedDict]]
11        [hy.macros [macroexpand :as mexpand]]
12        [hy.compiler [HyASTCompiler]])
13
(defn walk [inner outer form]
  "Generic traversal of an arbitrary data structure `form`.

  `inner` is applied to every element of `form` and `outer` to the
  rebuilt result:
  - a HyExpression is rebuilt as a HyExpression and handed to `outer`;
  - any other HySequence (or plain list) is rebuilt as a HyExpression,
    handed to `outer`, and the result is converted back to the type of
    `form`;
  - any other collection is first turned into a list and walked again,
    so its original type is NOT preserved;
  - anything else (an atom) is handed straight to `outer`."
  ;; deferred so `inner` only runs in the branches that rebuild `form`
  (setv rebuild (fn [] (HyExpression (map inner form))))
  (if (instance? HyExpression form)
      (outer (rebuild))
      (or (instance? HySequence form) (list? form))
      ((type form) (outer (rebuild)))
      (coll? form)
      (walk inner outer (list form))
      (outer form)))
26
(defn postwalk [f form]
  "Depth-first, post-order traversal of form: every sub-form is rewritten
  first, then f is called on the rebuilt parent. Each sub-form is replaced
  by whatever f returns for it."
  (walk (fn [sub-form] (postwalk f sub-form))
        f
        form))
31
(defn prewalk [f form]
  "Depth-first, pre-order traversal of form: f is called on a form before
  its children are visited, and the traversal descends into f's result.
  Each sub-form is replaced by whatever f returns for it."
  (setv replaced (f form))
  (walk (fn [sub-form] (prewalk f sub-form))
        identity
        replaced))
36
37;; TODO: move to hy.core?
(defn call? [form]
  "Returns form itself when it is a HyExpression (truthy only when it is
  non-empty); for anything else returns False."
  (if (instance? HyExpression form)
      form
      False))
42
(defn macroexpand-all [form &optional module-name]
  "Recursively performs all possible macroexpansions in form.

  Macros are resolved against the module named by module-name (loaded with
  importlib.import_module); when module-name is not given, the module
  returned by (calling-module) is used instead.
  A (quote ...) form outside any quasiquote is returned untouched, and
  inside a quasiquote only the unquoted parts are macroexpanded.
  A (require ...) form is compiled immediately with this function's
  HyASTCompiler and is replaced by None in the result."
  (setv module (or (and module-name
                        (import-module module-name))
                   (calling-module))
        quote-level 0
        ast-compiler (HyASTCompiler module))  ; compiles `require` forms; also handed to mexpand
  ;; expand every element of a collection form via walk
  (defn traverse [form]
    (walk expand identity form))
  (defn expand [form]
    (nonlocal quote-level)
    ;; manages quote levels
    ;; (+quote x) rebuilds form with its head kept and its tail expanded at
    ;; quote-level + x: x is 1 for quasiquote, -1 for unquote/unquote-splice
    (defn +quote [&optional [x 1]]
      (nonlocal quote-level)
      (setv head (first form))
      (+= quote-level x)
      ;; guard against an unquote that is not nested in a quasiquote
      (when (neg? quote-level)
        (raise (TypeError "unquote outside of quasiquote")))
      (setv res (traverse (cut form 1)))
      (-= quote-level x)
      `(~head ~@res))
    (if (call? form)
        (cond [quote-level
               ;; inside a quasiquote: only track nesting, never call mexpand
               (cond [(in (first form) '[unquote unquote-splice])
                      (+quote -1)]
                     [(= (first form) 'quasiquote) (+quote)]
                     [True (traverse form)])]
              [(= (first form) 'quote) form]
              [(= (first form) 'quasiquote) (+quote)]
              ;; compiled right away (presumably so the macros it requires are
              ;; registered before later forms are expanded); the form itself
              ;; becomes None
              [(= (first form) (HySymbol "require"))
               (ast-compiler.compile form)
               (return)]
              ;; expand this call, then keep expanding inside the result
              [True (traverse (mexpand form module ast-compiler))])
        ;; non-expression collections are walked; atoms are returned as-is
        (if (coll? form)
            (traverse form)
            form)))
  (expand form))
80
81;; TODO: move to hy.extra.reserved?
82(import hy)
;; Names of all Hy special forms. SymbolExpander.handle-call compares mangled
;; expression heads against this list to leave special-form heads unexpanded.
(setv special-forms (list hy.compiler._special-form-compilers))
84
85
(defn lambda-list [form]
  "
Splits a fn argument list into sections introduced by &-headers.

Returns an OrderedDict mapping each header to the list of arguments that
follow it, in order of appearance. Arguments before the first header are
stored under None. Each header opens a section at most once; a repeated
header is kept as an ordinary argument of the current section.
"
  (setv unused-headers ['&optional '&rest '&kwonly '&kwargs]
        current None
        sections (OrderedDict [(, current [])]))
  (for [arg form]
    (cond [(in arg unused-headers)
           ;; Don't use a header more than once. It's the compiler's problem.
           (.remove unused-headers arg)
           (setv current arg)
           (setv (get sections current) [])]
          [True
           (.append (get sections current) arg)]))
  sections)
104
105
(defn symbolexpand [form expander
                    &optional
                    [protected (frozenset)]
                    [quote-level 0]]
  "Rewrites the symbols of form with expander and returns the result.

  expander is called on each symbol eligible for expansion and its return
  value takes the symbol's place. Names in protected are never rewritten,
  and a non-zero quote-level marks form as sitting inside a quasiquote.
  The actual work is done by SymbolExpander."
  (setv expander-state (SymbolExpander form expander protected quote-level))
  (.expand expander-state))
111
(defclass SymbolExpander []
  "Rewrites the symbols of a single form by calling an expander function.

  form        -- the form to rewrite
  expander    -- called on each symbol that may be rewritten; its return
                 value replaces the symbol
  protected   -- names that must not be rewritten (e.g. fn arguments)
  quote-level -- quasiquote nesting depth; symbols are only handed to
                 expander at depth 0

  Call `expand` to obtain the rewritten form."

  (defn __init__ [self form expander protected quote-level]
    ;; `handle-global` adds names to the protected set in place so that the
    ;; sibling forms after a `global`/`nonlocal` declaration see them. A
    ;; frozenset (the `symbolexpand` default) cannot be updated, so take a
    ;; mutable copy of it; a set passed down by `traverse` is kept as-is so
    ;; that it stays shared between siblings.
    (setv self.form form
          self.expander expander
          self.protected (if (instance? frozenset protected)
                             (set protected)
                             protected)
          self.quote-level quote-level))

  ;; expand a single sub-form, inheriting this expander's protected names
  ;; and quote level unless they are overridden
  (defn expand-symbols [self form &optional protected quote-level]
    (if (none? protected)
        (setv protected self.protected))
    (if (none? quote-level)
        (setv quote-level self.quote-level))
    (symbolexpand form self.expander protected quote-level))

  ;; expand every element of a collection form (rebuilt via `walk`),
  ;; inheriting protected names and quote level unless overridden
  (defn traverse [self form &optional protected quote-level]
    (if (none? protected)
        (setv protected self.protected))
    (if (none? quote-level)
        (setv quote-level self.quote-level))
    (walk (partial symbolexpand
                   :expander self.expander
                   :protected protected
                   :quote-level quote-level)
          identity
          form))

  ;; manages quote levels: keep the head, expand the tail at quote-level + x
  (defn +quote [self &optional [x 1]]
    `(~(self.head) ~@(self.traverse (self.tail)
                                    :quote-level (+ self.quote-level x))))

  ;; (. obj attr [index] ...): expand the object and anything that is not a
  ;; bare symbol, but never the attribute names themselves
  (defn handle-dot [self]
    `(. ~(self.expand-symbols (first (self.tail)))
        ~@(walk (fn [form]
                  (if (symbol? form)
                      form  ; don't expand attrs
                      (self.expand-symbols form)))
                identity
                (cut (self.tail)
                     1))))

  (defn head [self]
    (first self.form))

  (defn tail [self]
    (cut self.form 1))

  (defn handle-except [self]
    (setv tail (self.tail))
    ;; protect the "as" name binding the exception
    `(~(self.head) ~@(self.traverse tail (| self.protected
                                            (if (and tail
                                                     (-> tail
                                                         first
                                                         len
                                                         (= 2)))
                                                #{(first (first tail))}
                                                #{})))))

  ;; Returns (, protected argslist): the names bound by the fn's argument
  ;; list, and that argument list with default-value forms expanded.
  (defn handle-args-list [self]
    (setv protected #{}
          argslist [])
    (for [[header section] (-> self (.tail) first lambda-list .items)]
      (if header (.append argslist header))
      (cond [(in header [None '&rest '&kwargs])
             (.update protected section)
             (.extend argslist section)]
            [(in header '[&optional &kwonly])
             (for [pair section]
               (cond [(coll? pair)
                      (.add protected (first pair))
                      (.append argslist
                               `[~(first pair)
                                 ~(self.expand-symbols (second pair))])]
                     [True
                      (.add protected pair)
                      (.append argslist pair)]))]))
    (, protected argslist))

  ;; the fn's parameters shadow outer bindings inside its body
  (defn handle-fn [self]
    (setv [protected argslist] (self.handle-args-list))
    `(~(self.head) ~argslist
      ~@(self.traverse (cut (self.tail) 1) (| protected self.protected))))

  ;; don't expand symbols in quotations
  (defn handle-quoted [self]
    (if (call? self.form)
        (if (in (self.head) '[unquote unquote-splice]) (self.+quote -1)
            (= (self.head) 'quasiquote) (self.+quote)
            (self.handle-coll))
        (if (coll? self.form)
            (self.handle-coll)
            (self.handle-base))))

  ;; convert dotted names to the standard special form
  (defn convert-dotted-symbol [self]
    (self.expand-symbols `(. ~@(map HySymbol (.split self.form '.)))))

  (defn expand-symbol [self]
    (if (not-in self.form self.protected)
        (self.expander self.form)
        (self.handle-base)))

  (defn handle-symbol [self]
    (if (and self.form
             (not (.startswith self.form '.))
             (in '. self.form))
        (self.convert-dotted-symbol)
        (self.expand-symbol)))

  ;; (global ...) / (nonlocal ...): the declared names refer to real Python
  ;; variables, so leave the form alone and protect those names for the
  ;; following sibling forms by updating the shared protected set in place
  (defn handle-global [self]
    (.update self.protected (set (self.tail)))
    (self.handle-base))

  (defn handle-defclass [self]
    ;; don't expand the name of the class
    `(~(self.head) ~(first (self.tail))
      ~@(self.traverse (cut (self.tail) 1))))

  (defn handle-special-form [self]
    ;; don't expand other special form symbols in head position
    `(~(self.head) ~@(self.traverse (self.tail))))

  (defn handle-base [self]
    self.form)

  (defn handle-coll [self]
    ;; recursion
    (self.traverse self.form))

  ;; We have to treat special forms differently.
  ;; Quotation should suppress symbol expansion,
  ;; and local bindings should shadow those made by let.
  (defn handle-call [self]
    (setv head (first self.form))
    (if (in head '[fn fn*]) (self.handle-fn)
        (in head '[import
                   require
                   quote
                   eval-and-compile
                   eval-when-compile]) (self.handle-base)
        ;; previously these fell through to handle-special-form, which
        ;; rewrote the declared names into invalid `(global (get ...))`
        (in head '[global nonlocal]) (self.handle-global)
        (= head 'except) (self.handle-except)
        (= head ".") (self.handle-dot)
        (= head 'defclass) (self.handle-defclass)
        (= head 'quasiquote) (self.+quote)
        ;; must be checked last! Only a symbol can name a special form, and
        ;; the guard keeps non-symbol heads away from `mangle`.
        (and (symbol? head) (in (mangle head) special-forms))
        (self.handle-special-form)
        ;; Not a special form. Traverse it like a coll
        (self.handle-coll)))

  (defn expand [self]
    "the main entry point. Call this to do the expansion"
    (setv form self.form)
    (if self.quote-level (self.handle-quoted)
        (symbol? form) (self.handle-symbol)
        (call? form) (self.handle-call)
        (coll? form) (self.handle-coll)
        ;; recursive base case--it's an atom. Put it back.
        (self.handle-base))))
271
(defmacro smacrolet [bindings &optional module-name &rest body]
  "
symbol macro let.

Replaces symbols in body, but only where it would be a valid let binding.
The bindings pairs the target symbol and the expansion form for that symbol.

An optional string literal right after the bindings names the module whose
macros are used to expand body; by default the module in which smacrolet is
used is taken.
"
  (if (odd? (len bindings))
      (macro-error bindings "bindings must be paired"))
  (for [k (cut bindings None None 2)]
    (if-not (symbol? k)
            (macro-error k "bind targets must be symbols")
            (if (in '. k)
                (macro-error k "binding target may not contain a dot"))))
  ;; `module-name` is positional, so without an explicit module string it
  ;; swallows the first body form. Anything that is not a plain string
  ;; (symbols are str subclasses, so exclude them explicitly) is put back
  ;; at the front of body.
  (when (and (not (none? module-name))
             (not (and (string? module-name)
                       (not (symbol? module-name)))))
    (setv body (+ (, module-name) body)
          module-name None))
  ;; `&name` is the invoking module's name, as used by `let`
  (setv bindings (dict (partition bindings))
        body (macroexpand-all body (or module-name &name)))
  (symbolexpand `(do ~@body)
                (fn [symbol]
                  (.get bindings symbol symbol))))
291
(defmacro let [bindings &rest body]
  "
Sets up lexical bindings in its body.

Bindings are processed sequentially,
so you can use the result of an earlier binding in a later one.

Basic assignments (e.g. setv, +=) will update the let binding
if they use the name of a let binding.

But assignments via `import` are always hoisted to normal Python scope, and
likewise, `defclass` will assign the class to the Python scope,
even if it shares the name of a let binding.

Use `import_module` and `type` (or whatever metaclass) instead
if you must avoid this hoisting.

Function arguments can shadow let bindings in their body,
as can nested let forms.
"
  (when (odd? (len bindings))
    (macro-error bindings "let bindings must be paired"))
  ;; every binding lives in one generated dict, keyed by the binding's name
  (setv store (gensym 'let)
        replacements (OrderedDict)
        assignments [])
  (defn expander [symbol]
    (.get replacements symbol symbol))
  (for [[target value] (partition bindings)]
    (cond [(not (symbol? target))
           (macro-error target "bind targets must be symbols")]
          [(in '. target)
           (macro-error target "binding target may not contain a dot")])
    ;; expand the value BEFORE registering the target, so the value still
    ;; sees any earlier binding of the same name
    (setv expanded (symbolexpand (macroexpand-all value &name)
                                 expander)
          slot `(get ~store ~(name target)))
    (.extend assignments [slot expanded])
    (assoc replacements target slot))
  `(do
     (setv ~store {}
           ~@assignments)
     ~@(symbolexpand (macroexpand-all body &name)
                     expander)))
334
335;; (defmacro macrolet [])
336