1#lang typed/racket/base
2
3(require racket/vector
4         "../unsafe.rkt"
5         "array-struct.rkt"
6         "array-broadcast.rkt"
7         "utils.rkt")
8
9(provide (all-defined-out))
10
11;; ===================================================================================================
12;; Arbitrary transforms
13
;; Builds an array of shape `new-ds` whose element at index vector `js` is the element of
;; `arr` at index vector `(idx-fun js)`.
;; Every index vector produced by `idx-fun` is passed through `check-array-indexes`
;; against `arr`'s shape (tagged with the name 'array-transform) before the lookup, so
;; `idx-fun` may return out-of-range indexes only at the cost of that check failing.
(: array-transform (All (A) ((Array A) In-Indexes (Indexes -> In-Indexes) -> (Array A))))
(define (array-transform arr new-ds idx-fun)
  (define src-shape (array-shape arr))
  (define src-proc (unsafe-array-proc arr))
  (: lookup (Indexes -> A))
  (define (lookup js)
    ;; Map the new index vector back into `arr`, validating it first.
    (define src-js (check-array-indexes 'array-transform src-shape (idx-fun js)))
    (src-proc src-js))
  (build-array new-ds lookup))
21
;; Unchecked counterpart of `array-transform`: builds an array of shape `new-ds` whose
;; element at `js` is `arr`'s element at `(idx-fun js)`.
;; No bounds checking is done on the indexes returned by `idx-fun`; the caller must
;; guarantee they are valid for `arr`'s shape.
(: unsafe-array-transform (All (A) ((Array A) Indexes (Indexes -> Indexes) -> (Array A))))
(define (unsafe-array-transform arr new-ds idx-fun)
  (define src-proc (unsafe-array-proc arr))
  (: transformed-proc (Indexes -> A))
  (define (transformed-proc js)
    (src-proc (idx-fun js)))
  (unsafe-build-array new-ds transformed-proc))
26
27;; ===================================================================================================
;; Axis permutation and swap
29
;; Returns an array whose axes are a permutation of `arr`'s axes.
;; `perm` is handed to `apply-permutation` (defined elsewhere) together with `arr`'s shape;
;; if it rejects `perm`, an argument error ("permutation", argument 1) is raised.
;; The element of the result at index vector `js` is `arr`'s element at the index vector
;; `old-js` built by setting `old-js[perm[i]] := js[i]` for each axis `i`, where `perm` is the
;; value returned by `apply-permutation`.
(: array-axis-permute (All (A) ((Array A) (Listof Integer) -> (Array A))))
(define (array-axis-permute arr perm)
  (define ds (array-shape arr))
  ;; Rebind `ds` and `perm` to `apply-permutation`'s results. Presumably these are the
  ;; permuted shape and the permutation as a vector -- TODO confirm against its definition.
  (let-values ([(ds perm) (apply-permutation
                           perm ds (λ () (raise-argument-error 'array-axis-permute "permutation"
                                                               1 arr perm)))])
    (define dims (vector-length ds))
    ;; Scratch index vector for `arr`, created once and reused on every lookup instead of
    ;; allocating a new vector per element. `make-thread-local-indexes` is defined elsewhere;
    ;; presumably it yields a per-thread buffer -- verify.
    (define old-js (make-thread-local-indexes dims))
    (array-default-strict
     (unsafe-array-transform
      arr ds
      (λ: ([js : Indexes])
        (let ([old-js  (old-js)])
          ;; Scatter each index of the new array into the source axis it came from.
          ;; Assumes `perm` is a valid permutation (checked by `apply-permutation`), so
          ;; every slot of `old-js` is overwritten.
          (let: loop : Indexes ([i : Nonnegative-Fixnum  0])
            (cond [(i . < . dims)  (unsafe-vector-set! old-js
                                                       (unsafe-vector-ref perm i)
                                                       (unsafe-vector-ref js i))
                                   (loop (+ i 1))]
                  [else  old-js]))))))))
49
;; Returns an array with axes `i0` and `i1` of `arr` exchanged.
;; The result's shape is `arr`'s shape with entries `i0` and `i1` swapped. The element at
;; index vector `js` is `arr`'s element at `js` with `js[i0]` and `js[i1]` swapped.
;; Returns `arr` itself when `i0 = i1`.
;; Raises an argument error naming `array-axis-swap` if either axis is not in [0, dims).
(: array-axis-swap (All (A) ((Array A) Integer Integer -> (Array A))))
(define (array-axis-swap arr i0 i1)
  (define ds (array-shape arr))
  (define dims (vector-length ds))
  ;; Report errors under this function's own name (previously mislabeled 'array-transpose).
  (cond [(or (i0 . < . 0) (i0 . >= . dims))
         (raise-argument-error 'array-axis-swap (format "Index < ~a" dims) 1 arr i0 i1)]
        [(or (i1 . < . 0) (i1 . >= . dims))
         (raise-argument-error 'array-axis-swap (format "Index < ~a" dims) 2 arr i0 i1)]
        [(= i0 i1)  arr]
        [else
         ;; Swap the two extents in a fresh copy so `arr`'s own shape vector is not modified.
         (define new-ds (vector-copy-all ds))
         (define j0 (unsafe-vector-ref new-ds i0))
         (define j1 (unsafe-vector-ref new-ds i1))
         (unsafe-vector-set! new-ds i0 j1)
         (unsafe-vector-set! new-ds i1 j0)
         (define proc (unsafe-array-proc arr))
         (array-default-strict
          (unsafe-build-array
           new-ds (λ: ([js : Indexes])
                    ;; Swap the two indexes in place, call the original procedure, then
                    ;; swap them back so `js` is unchanged when this returns.
                    (define j0 (unsafe-vector-ref js i0))
                    (define j1 (unsafe-vector-ref js i1))
                    (unsafe-vector-set! js i0 j1)
                    (unsafe-vector-set! js i1 j0)
                    (define v (proc js))
                    (unsafe-vector-set! js i0 j0)
                    (unsafe-vector-set! js i1 j1)
                    v)))]))
77
78;; ===================================================================================================
79;; Adding/removing axes
80
;; Returns an array with a new axis of length `dk` (default 1) inserted at position `k`,
;; where 0 <= k <= dims.
;; The element at index vector `js` is `arr`'s element at `js` with entry `k` dropped (via
;; `unsafe-vector-remove`), so `arr`'s contents are repeated `dk` times along the new axis.
;; Raises an argument error if `k` is out of range or `dk` is not an Index.
(: array-axis-insert (All (A) (case-> ((Array A) Integer -> (Array A))
                                      ((Array A) Integer Integer -> (Array A)))))
(define (array-axis-insert arr k [dk 1])
  (define shape (array-shape arr))
  (define rank (vector-length shape))
  (cond [(or (k . < . 0) (k . > . rank))
         (raise-argument-error 'array-axis-insert (format "Index <= ~a" rank) 1 arr k dk)]
        [(not (index? dk))
         (raise-argument-error 'array-axis-insert "Index" 2 arr k dk)]
        [else
         (define src-proc (unsafe-array-proc arr))
         (define new-shape (unsafe-vector-insert shape k dk))
         (: drop-new-axis (Indexes -> A))
         (define (drop-new-axis js)
           (src-proc (unsafe-vector-remove js k)))
         (array-default-strict (unsafe-build-array new-shape drop-new-axis))]))
95
;; Returns the slice of `arr` at index `jk` along axis `k`; the result has one fewer axis.
;; The element at index vector `js` is `arr`'s element at `js` with `jk` inserted at
;; position `k` (via `unsafe-vector-insert`).
;; Raises an argument error if `k` is not a valid axis or `jk` is out of range for it.
(: array-axis-ref (All (A) ((Array A) Integer Integer -> (Array A))))
(define (array-axis-ref arr k jk)
  (define shape (array-shape arr))
  (define rank (vector-length shape))
  (cond
    [(or (k . < . 0) (k . >= . rank))
     (raise-argument-error 'array-axis-ref (format "Index < ~a" rank) 1 arr k jk)]
    [else
     ;; Length of the axis being indexed; only read once `k` is known to be valid.
     (define axis-len (unsafe-vector-ref shape k))
     (cond
       [(or (jk . < . 0) (jk . >= . axis-len))
        (raise-argument-error 'array-axis-ref (format "Index < ~a" axis-len)
                              2 arr k jk)]
       [else
        (define src-proc (unsafe-array-proc arr))
        (define new-shape (unsafe-vector-remove shape k))
        (array-default-strict
         (unsafe-build-array
          new-shape
          (λ: ([js : Indexes]) (src-proc (unsafe-vector-insert js k jk)))))])]))
110
111;; ===================================================================================================
112;; Reshape
113
;; Returns an array with shape `ds` holding `arr`'s elements in the same linear order.
;; Each index vector of the result is converted to a flat (value) index using the new
;; shape, then back to an index vector using `arr`'s shape.
;; Returns `arr` itself when the shape is unchanged.
;; Raises an argument error if `ds` is not a valid shape or its size differs from `arr`'s.
(: array-reshape (All (A) ((Array A) In-Indexes -> (Array A))))
(define (array-reshape arr ds)
  (define new-shape
    (check-array-shape
     ds (λ () (raise-argument-error 'array-reshape "(Vectorof Index)" 1 arr ds))))
  (define size (array-size arr))
  (unless (= size (array-shape-size new-shape))
    (raise-argument-error 'array-reshape (format "(Vectorof Index) with product ~a" size)
                          1 arr new-shape))
  (define src-shape (array-shape arr))
  (cond
    [(equal? new-shape src-shape)  arr]
    [else
     (define src-proc (unsafe-array-proc arr))
     ;; Reusable scratch index vector for `arr`, avoiding one allocation per element.
     (define scratch (make-thread-local-indexes (vector-length src-shape)))
     (array-default-strict
      (unsafe-build-array
       new-shape
       (λ: ([js : Indexes])
         (define src-js (scratch))
         (unsafe-value-index->array-index!
          src-shape (unsafe-array-index->value-index new-shape js) src-js)
         (src-proc src-js))))]))
134
;; Returns a one-dimensional array of length `(array-size arr)` holding `arr`'s elements
;; in linear (value-index) order. Returns `arr` itself when it is already one-dimensional.
(: array-flatten (All (A) ((Array A) -> (Array A))))
(define (array-flatten arr)
  (define size (array-size arr))
  (define: flat-shape : Indexes (vector size))
  (define src-shape (array-shape arr))
  (if (equal? flat-shape src-shape)
      arr
      (let ([src-proc  (unsafe-array-proc arr)]
            ;; Reusable scratch index vector for `arr`, avoiding one allocation per element.
            [scratch   (make-thread-local-indexes (vector-length src-shape))])
        (array-default-strict
         (unsafe-build-array
          flat-shape
          (λ: ([js : Indexes])
            (define src-js (scratch))
            ;; The single index of the flat array is the value index into `arr`.
            (unsafe-value-index->array-index! src-shape (unsafe-vector-ref js 0) src-js)
            (src-proc src-js)))))))
152
153;; ===================================================================================================
154;; Append
155
;; Prepares `arrs` for concatenation along axis `k`. The steps are:
;; - every shape is left-padded with 1s to the largest number of axes;
;; - all axes except `k` are broadcast to a common shape;
;; - each array is broadcast to that common shape, keeping its own length along `k`.
;; Returns the broadcast arrays and the list of their lengths along axis `k`.
;; Raises an argument error (attributed to `array-append*`) if `k` is not a valid axis.
(: array-broadcast-for-append (All (A) ((Listof (Array A))
                                        Integer -> (Values (Listof (Array A))
                                                           (Listof Index)))))
(define (array-broadcast-for-append arrs k)
  (define shapes (map (λ: ([arr : (Array A)]) (array-shape arr)) arrs))
  (define dims (apply max (map vector-length shapes)))
  (cond [(not (index? dims))  (error 'array-broadcast-for-append "can't happen")]
        [(or (k . < . 0) (k . >= . dims))
         (raise-argument-error 'array-append* (format "Index < ~a" dims) k)]
        [else
         (define padded-shapes
           (map (λ: ([ds : Indexes])
                  (vector-append ((inst make-vector Index) (- dims (vector-length ds)) 1) ds))
                shapes))
         (define axis-lengths
           (map (λ: ([ds : Indexes]) (unsafe-vector-ref ds k)) padded-shapes))
         (define other-axes
           (array-shape-broadcast
            (map (λ: ([ds : Indexes]) (unsafe-vector-remove ds k)) padded-shapes)))
         (define target-shapes
           (map (λ: ([dk : Index]) (unsafe-vector-insert other-axes k dk)) axis-lengths))
         (values (map (λ: ([arr : (Array A)] [ds : Indexes]) (array-broadcast arr ds))
                      arrs target-shapes)
                 axis-lengths)]))
177
;; Concatenates the arrays in `arrs` along axis `k` (default 0).
;; The arrays are first made compatible by `array-broadcast-for-append`, which pads them to
;; the same number of axes and broadcasts every axis except `k`. The result's length along
;; `k` is the sum of the inputs' lengths along `k`.
;; Raises an argument error if `arrs` is empty (or `k` is invalid, via the broadcast helper).
;; Raises an error if the summed axis length is not an Index.
(: array-append* (All (A) (case-> ((Listof (Array A)) -> (Array A))
                                  ((Listof (Array A)) Integer -> (Array A)))))
(define (array-append* arrs [k 0])
  (when (null? arrs) (raise-argument-error 'array-append* "nonempty (Listof (Array A))" arrs))
  (let-values ([(arrs dks)  (array-broadcast-for-append arrs k)])
    (define new-dk (apply + dks))
    (cond
      [(not (index? new-dk))  (error 'array-append* "resulting axis is too large (not an Index)")]
      [else
       ;; After broadcasting, all shapes agree except along axis `k`, so start from the
       ;; first array's shape and replace its axis-`k` length with the total.
       (define dss (map (λ: ([arr : (Array A)]) (array-shape arr)) arrs))
       (define new-ds (vector-copy-all (car dss)))
       (unsafe-vector-set! new-ds k new-dk)
       ;; Make two mappings, both indexed by position along axis `k` of the new array:
       ;; 1. old-procs : new array index -> procedure of the source array it falls in
       ;; 2. old-jks :   new array index -> index along axis `k` within that source array
       (define old-procs (make-vector new-dk (unsafe-array-proc (car arrs))))
       (define: old-jks : Indexes (make-vector new-dk 0))
       ;; Walk the arrays in order. `jk` is the next free position in the tables, and each
       ;; array claims `dk` consecutive positions with local indexes 0 .. dk-1.
       (let arrs-loop ([arrs arrs] [dks dks] [#{jk : Nonnegative-Fixnum} 0])
         (unless (null? arrs)
           (define arr (car arrs))
           (define proc (unsafe-array-proc arr))
           (define dk (car dks))
           (let i-loop ([#{i : Nonnegative-Fixnum} 0] [#{jk : Nonnegative-Fixnum} jk])
             (cond [(i . < . dk)  (unsafe-vector-set! old-procs jk proc)
                                  (unsafe-vector-set! old-jks jk i)
                                  (i-loop (+ i 1) (unsafe-fx+ jk 1))]
                   [else  (arrs-loop (cdr arrs) (cdr dks) jk)]))))
       (array-default-strict
        (unsafe-build-array
         new-ds (λ: ([js : Indexes])
                  ;; Temporarily replace js[k] with the source array's local index, call
                  ;; that array's procedure, then restore js[k] so `js` is left unchanged.
                  (define jk (unsafe-vector-ref js k))
                  (unsafe-vector-set! js k (unsafe-vector-ref old-jks jk))
                  (define v ((unsafe-vector-ref old-procs jk) js))
                  (unsafe-vector-set! js k jk)
                  v)))])))
213