#lang typed/racket/base

(require racket/vector
         "../unsafe.rkt"
         "array-struct.rkt"
         "array-broadcast.rkt"
         "utils.rkt")

(provide (all-defined-out))

;; ===================================================================================================
;; Arbitrary transforms

;; Returns an array of shape `new-ds` whose element at index `js` is the element of `arr` at index
;; `(idx-fun js)`. Every index produced by `idx-fun` is passed through `check-array-indexes` against
;; `arr`'s shape before it is used (presumably raising on a bad index, with `array-transform` as the
;; reported name -- confirm against that helper's definition).
(: array-transform (All (A) ((Array A) In-Indexes (Indexes -> In-Indexes) -> (Array A))))
(define (array-transform arr new-ds idx-fun)
  (define old-ds (array-shape arr))
  (define old-f (unsafe-array-proc arr))
  (build-array
   new-ds (λ: ([js : Indexes])
            (old-f (check-array-indexes 'array-transform old-ds (idx-fun js))))))

;; Unchecked variant of `array-transform`: the result is built with `unsafe-build-array`, and the
;; indexes returned by `idx-fun` are handed straight to `arr`'s procedure with no bounds check.
;; The caller must guarantee `idx-fun` maps every index valid for `new-ds` to one valid for `arr`.
(: unsafe-array-transform (All (A) ((Array A) Indexes (Indexes -> Indexes) -> (Array A))))
(define (unsafe-array-transform arr new-ds idx-fun)
  (define old-f (unsafe-array-proc arr))
  (unsafe-build-array new-ds (λ: ([js : Indexes]) (old-f (idx-fun js)))))

;; ===================================================================================================
;; Back permutation and swap

;; Permutes the axes of `arr`: component `i` of a result index is used as component `perm[i]` of the
;; index into `arr`. `apply-permutation` is given the shape and an error thunk; it returns the new
;; shape and `perm` as a vector (presumably calling the thunk -- which raises an argument error for
;; `array-axis-permute` -- when `perm` is not a valid permutation of the axes; confirm in its
;; definition).
(: array-axis-permute (All (A) ((Array A) (Listof Integer) -> (Array A))))
(define (array-axis-permute arr perm)
  (define ds (array-shape arr))
  (let-values ([(ds perm) (apply-permutation
                           perm ds (λ () (raise-argument-error 'array-axis-permute "permutation"
                                                               1 arr perm)))])
    (define dims (vector-length ds))
    ;; Reusable scratch index vector (from `make-thread-local-indexes`, presumably one per
    ;; thread), so that no index vector is allocated per element. Every slot is overwritten on
    ;; each call because `perm` covers all `dims` positions.
    (define old-js (make-thread-local-indexes dims))
    (array-default-strict
     (unsafe-array-transform
      arr ds
      (λ: ([js : Indexes])
        (let ([old-js (old-js)])
          (let: loop : Indexes ([i : Nonnegative-Fixnum 0])
            (cond [(i . < . dims) (unsafe-vector-set! old-js
                                                      (unsafe-vector-ref perm i)
                                                      (unsafe-vector-ref js i))
                                  (loop (+ i 1))]
                  [else old-js]))))))))

;; Swaps axes `i0` and `i1` of `arr`. Returns `arr` itself when `i0` = `i1`.
;; Raises an argument error, reported as coming from `array-axis-swap`, when `i0` (position 1) or
;; `i1` (position 2) is not in [0, dims).
(: array-axis-swap (All (A) ((Array A) Integer Integer -> (Array A))))
(define (array-axis-swap arr i0 i1)
  (define ds (array-shape arr))
  (define dims (vector-length ds))
  (cond [(or (i0 . < . 0) (i0 . >= . dims))
         (raise-argument-error 'array-axis-swap (format "Index < ~a" dims) 1 arr i0 i1)]
        [(or (i1 . < . 0) (i1 . >= . dims))
         (raise-argument-error 'array-axis-swap (format "Index < ~a" dims) 2 arr i0 i1)]
        [(= i0 i1) arr]
        [else
         ;; New shape: a copy of `ds` with entries i0 and i1 exchanged.
         (define new-ds (vector-copy-all ds))
         (define j0 (unsafe-vector-ref new-ds i0))
         (define j1 (unsafe-vector-ref new-ds i1))
         (unsafe-vector-set! new-ds i0 j1)
         (unsafe-vector-set! new-ds i1 j0)
         (define proc (unsafe-array-proc arr))
         (array-default-strict
          (unsafe-build-array
           new-ds (λ: ([js : Indexes])
                    ;; Swap components i0/i1 of `js` in place to form the source index, read the
                    ;; element, then swap them back before returning, avoiding an allocation.
                    ;; NOTE(review): if `proc` raises, `js` is left swapped.
                    (define j0 (unsafe-vector-ref js i0))
                    (define j1 (unsafe-vector-ref js i1))
                    (unsafe-vector-set! js i0 j1)
                    (unsafe-vector-set! js i1 j0)
                    (define v (proc js))
                    (unsafe-vector-set! js i0 j0)
                    (unsafe-vector-set! js i1 j1)
                    v)))]))

;; ===================================================================================================
;; Adding/removing axes

;; Inserts a new axis of length `dk` (default 1) at position `k`, where 0 <= k <= dims.
;; The element at result index `js` is `arr`'s element at `js` with component `k` removed, so the
;; contents are repeated along the new axis.
;; Raises an argument error when `k` is out of range (position 1) or `dk` is not an Index
;; (position 2).
(: array-axis-insert (All (A) (case-> ((Array A) Integer -> (Array A))
                                      ((Array A) Integer Integer -> (Array A)))))
(define (array-axis-insert arr k [dk 1])
  (define ds (array-shape arr))
  (define dims (vector-length ds))
  (cond [(or (k . < . 0) (k . > . dims))
         (raise-argument-error 'array-axis-insert (format "Index <= ~a" dims) 1 arr k dk)]
        [(not (index? dk))
         (raise-argument-error 'array-axis-insert "Index" 2 arr k dk)]
        [else
         (define new-ds (unsafe-vector-insert ds k dk))
         (define proc (unsafe-array-proc arr))
         (array-default-strict
          (unsafe-build-array new-ds (λ: ([js : Indexes]) (proc (unsafe-vector-remove js k)))))]))

;; Fixes axis `k` of `arr` at index `jk` and drops that axis: the result has one fewer dimension,
;; and its element at `js` is `arr`'s element at `js` with `jk` inserted as component `k`.
;; Raises an argument error when `k` is not in [0, dims) (position 1) or `jk` is not a valid index
;; along axis `k` (position 2).
(: array-axis-ref (All (A) ((Array A) Integer Integer -> (Array A))))
(define (array-axis-ref arr k jk)
  (define ds (array-shape arr))
  (define dims (vector-length ds))
  (cond [(or (k . < . 0) (k . >= . dims))
         (raise-argument-error 'array-axis-ref (format "Index < ~a" dims) 1 arr k jk)]
        [(or (jk . < . 0) (jk . >= . (unsafe-vector-ref ds k)))
         (raise-argument-error 'array-axis-ref (format "Index < ~a" (unsafe-vector-ref ds k))
                               2 arr k jk)]
        [else
         (define new-ds (unsafe-vector-remove ds k))
         (define proc (unsafe-array-proc arr))
         (array-default-strict
          (unsafe-build-array new-ds (λ: ([js : Indexes]) (proc (unsafe-vector-insert js k jk)))))]))

;; ===================================================================================================
;; Reshape

;; Returns an array with shape `ds` holding `arr`'s elements in the same linear order. Each result
;; index is converted to a flat position with `unsafe-array-index->value-index` and back to an `arr`
;; index with `unsafe-value-index->array-index!` (presumably row-major -- confirm in those helpers).
;; Returns `arr` itself when `ds` equals its shape.
;; Raises an argument error when `ds` is rejected by `check-array-shape`, or when the product of
;; `ds` differs from `(array-size arr)`.
(: array-reshape (All (A) ((Array A) In-Indexes -> (Array A))))
(define (array-reshape arr ds)
  (let ([ds (check-array-shape
             ds (λ () (raise-argument-error 'array-reshape "(Vectorof Index)" 1 arr ds)))])
    (define size (array-size arr))
    (unless (= size (array-shape-size ds))
      (raise-argument-error 'array-reshape (format "(Vectorof Index) with product ~a" size) 1 arr ds))
    (define old-ds (array-shape arr))
    (cond [(equal? ds old-ds) arr]
          [else
           (define old-dims (vector-length old-ds))
           (define g (unsafe-array-proc arr))
           ;; Reusable scratch vector for the source index (see `make-thread-local-indexes`).
           (define old-js (make-thread-local-indexes old-dims))
           (array-default-strict
            (unsafe-build-array
             ds (λ: ([js : Indexes])
                  (let ([old-js (old-js)])
                    (define j (unsafe-array-index->value-index ds js))
                    (unsafe-value-index->array-index! old-ds j old-js)
                    (g old-js)))))])))

;; Returns a one-dimensional array of length `(array-size arr)` holding `arr`'s elements in the
;; linear order defined by `unsafe-value-index->array-index!`. Returns `arr` itself when it already
;; has that one-dimensional shape.
(: array-flatten (All (A) ((Array A) -> (Array A))))
(define (array-flatten arr)
  (define size (array-size arr))
  (define: ds : Indexes (vector size))
  (define old-ds (array-shape arr))
  (cond [(equal? ds old-ds) arr]
        [else
         (define old-dims (vector-length old-ds))
         (define g (unsafe-array-proc arr))
         ;; Reusable scratch vector for the source index (see `make-thread-local-indexes`).
         (define old-js (make-thread-local-indexes old-dims))
         (array-default-strict
          (unsafe-build-array
           ds (λ: ([js : Indexes])
                (let ([old-js (old-js)])
                  (define j (unsafe-vector-ref js 0))
                  (unsafe-value-index->array-index! old-ds j old-js)
                  (g old-js)))))]))

;; ===================================================================================================
;; Append

;; Prepares `arrs` for concatenation along axis `k`:
;;  1. every shape is left-padded with 1s up to the largest number of dimensions, `dims`;
;;  2. axis `k` is taken out of each padded shape, and the remaining shapes are combined with
;;     `array-shape-broadcast`;
;;  3. each array is broadcast (with `array-broadcast`) to that common shape, with its own length
;;     along `k` put back.
;; Returns the broadcast arrays and, in the same order, their lengths along axis `k`.
;; `k` refers to the padded shapes; an argument error (attributed to `array-append*`) is raised
;; when it is not in [0, dims). `arrs` must be nonempty (`(apply max '())` fails otherwise).
(: array-broadcast-for-append (All (A) ((Listof (Array A))
                                        Integer -> (Values (Listof (Array A))
                                                           (Listof Index)))))
(define (array-broadcast-for-append arrs k)
  (define dss (map (λ: ([arr : (Array A)]) (array-shape arr)) arrs))
  (define dims (apply max (map vector-length dss)))
  ;; `dims` is a maximum of vector lengths; this check exists to give it the type Index.
  (cond [(not (index? dims)) (error 'array-broadcast-for-append "can't happen")]
        [(or (k . < . 0) (k . >= . dims))
         (raise-argument-error 'array-append* (format "Index < ~a" dims) k)]
        [else
         (let* ([dss (map (λ: ([ds : Indexes])
                            (define dms (vector-length ds))
                            (vector-append ((inst make-vector Index) (- dims dms) 1) ds))
                          dss)]
                [dks (map (λ: ([ds : Indexes]) (unsafe-vector-ref ds k)) dss)]
                [dss (map (λ: ([ds : Indexes]) (unsafe-vector-remove ds k)) dss)]
                [ds (array-shape-broadcast dss)]
                [dss (map (λ: ([dk : Index]) (unsafe-vector-insert ds k dk)) dks)])
           (define new-arrs
             (map (λ: ([arr : (Array A)] [ds : Indexes]) (array-broadcast arr ds)) arrs dss))
           (values new-arrs dks))]))

;; Concatenates `arrs` along axis `k` (default 0), after broadcasting them with
;; `array-broadcast-for-append` so that all other axes agree. The result's length along `k` is the
;; sum of the broadcast arrays' lengths along `k`.
;; Raises an argument error when `arrs` is empty (or when `k` is out of range, via
;; `array-broadcast-for-append`), and errors when that summed length is not an Index.
(: array-append* (All (A) (case-> ((Listof (Array A)) -> (Array A))
                                  ((Listof (Array A)) Integer -> (Array A)))))
(define (array-append* arrs [k 0])
  (when (null? arrs) (raise-argument-error 'array-append* "nonempty (Listof (Array A))" arrs))
  (let-values ([(arrs dks) (array-broadcast-for-append arrs k)])
    (define new-dk (apply + dks))
    (cond
      [(not (index? new-dk)) (error 'array-append* "resulting axis is too large (not an Index)")]
      [else
       (define dss (map (λ: ([arr : (Array A)]) (array-shape arr)) arrs))
       (define new-ds (vector-copy-all (car dss)))
       (unsafe-vector-set! new-ds k new-dk)
       ;; Make two mappings, each indexed by a position along axis k of the result:
       ;;  1. old-procs : new array index -> procedure of the source array covering that position
       ;;  2. old-jks   : new array index -> index along axis k within that source array
       (define old-procs (make-vector new-dk (unsafe-array-proc (car arrs))))
       (define: old-jks : Indexes (make-vector new-dk 0))
       (let arrs-loop ([arrs arrs] [dks dks] [#{jk : Nonnegative-Fixnum} 0])
         (unless (null? arrs)
           (define arr (car arrs))
           (define proc (unsafe-array-proc arr))
           (define dk (car dks))
           (let i-loop ([#{i : Nonnegative-Fixnum} 0] [#{jk : Nonnegative-Fixnum} jk])
             (cond [(i . < . dk) (unsafe-vector-set! old-procs jk proc)
                                 (unsafe-vector-set! old-jks jk i)
                                 (i-loop (+ i 1) (unsafe-fx+ jk 1))]
                   [else (arrs-loop (cdr arrs) (cdr dks) jk)]))))
       (array-default-strict
        (unsafe-build-array
         new-ds (λ: ([js : Indexes])
                  ;; Temporarily replace component k with the source array's local index, read
                  ;; the element, then put the original component back before returning.
                  ;; NOTE(review): if the source procedure raises, `js` is not restored.
                  (define jk (unsafe-vector-ref js k))
                  (unsafe-vector-set! js k (unsafe-vector-ref old-jks jk))
                  (define v ((unsafe-vector-ref old-procs jk) js))
                  (unsafe-vector-set! js k jk)
                  v)))])))