1{-# LANGUAGE CPP, MultiParamTypeClasses, FlexibleContexts, BangPatterns, TypeFamilies, ScopedTypeVariables #-}
2-- |
3-- Module      : Data.Vector.Generic.Mutable
4-- Copyright   : (c) Roman Leshchinskiy 2008-2010
5-- License     : BSD-style
6--
7-- Maintainer  : Roman Leshchinskiy <rl@cse.unsw.edu.au>
8-- Stability   : experimental
9-- Portability : non-portable
10--
11-- Generic interface to mutable vectors
12--
13
14module Data.Vector.Generic.Mutable (
15  -- * Class of mutable vector types
16  MVector(..),
17
18  -- * Accessors
19
20  -- ** Length information
21  length, null,
22
23  -- ** Extracting subvectors
24  slice, init, tail, take, drop, splitAt,
25  unsafeSlice, unsafeInit, unsafeTail, unsafeTake, unsafeDrop,
26
27  -- ** Overlapping
28  overlaps,
29
30  -- * Construction
31
32  -- ** Initialisation
33  new, unsafeNew, replicate, replicateM, clone,
34
35  -- ** Growing
36  grow, unsafeGrow,
37  growFront, unsafeGrowFront,
38
39  -- ** Restricting memory usage
40  clear,
41
42  -- * Accessing individual elements
43  read, write, modify, swap, exchange,
44  unsafeRead, unsafeWrite, unsafeModify, unsafeSwap, unsafeExchange,
45
46  -- * Modifying vectors
47  nextPermutation,
48
49  -- ** Filling and copying
50  set, copy, move, unsafeCopy, unsafeMove,
51
52  -- * Internal operations
53  mstream, mstreamR,
54  unstream, unstreamR, vunstream,
55  munstream, munstreamR,
56  transform, transformR,
57  fill, fillR,
58  unsafeAccum, accum, unsafeUpdate, update, reverse,
59  unstablePartition, unstablePartitionBundle, partitionBundle,
60  partitionWithBundle
61) where
62
63import           Data.Vector.Generic.Mutable.Base
64import qualified Data.Vector.Generic.Base as V
65
66import qualified Data.Vector.Fusion.Bundle      as Bundle
67import           Data.Vector.Fusion.Bundle      ( Bundle, MBundle, Chunk(..) )
68import qualified Data.Vector.Fusion.Bundle.Monadic as MBundle
69import           Data.Vector.Fusion.Stream.Monadic ( Stream )
70import qualified Data.Vector.Fusion.Stream.Monadic as Stream
71import           Data.Vector.Fusion.Bundle.Size
72import           Data.Vector.Fusion.Util        ( delay_inline )
73
74import Control.Monad.Primitive ( PrimMonad, PrimState )
75
76import Prelude hiding ( length, null, replicate, reverse, map, read,
77                        take, drop, splitAt, init, tail )
78
79#include "vector.h"
80
81{-
82type family Immutable (v :: * -> * -> *) :: * -> *
83
84-- | Class of mutable vectors parametrised with a primitive state token.
85--
86class MBundle.Pointer u a => MVector v a where
87  -- | Length of the mutable vector. This method should not be
88  -- called directly, use 'length' instead.
89  basicLength       :: v s a -> Int
90
91  -- | Yield a part of the mutable vector without copying it. This method
92  -- should not be called directly, use 'unsafeSlice' instead.
93  basicUnsafeSlice :: Int  -- ^ starting index
94                   -> Int  -- ^ length of the slice
95                   -> v s a
96                   -> v s a
97
98  -- Check whether two vectors overlap. This method should not be
99  -- called directly, use 'overlaps' instead.
100  basicOverlaps    :: v s a -> v s a -> Bool
101
102  -- | Create a mutable vector of the given length. This method should not be
103  -- called directly, use 'unsafeNew' instead.
104  basicUnsafeNew   :: PrimMonad m => Int -> m (v (PrimState m) a)
105
106  -- | Create a mutable vector of the given length and fill it with an
107  -- initial value. This method should not be called directly, use
108  -- 'replicate' instead.
109  basicUnsafeReplicate :: PrimMonad m => Int -> a -> m (v (PrimState m) a)
110
111  -- | Yield the element at the given position. This method should not be
112  -- called directly, use 'unsafeRead' instead.
113  basicUnsafeRead  :: PrimMonad m => v (PrimState m) a -> Int -> m a
114
115  -- | Replace the element at the given position. This method should not be
116  -- called directly, use 'unsafeWrite' instead.
117  basicUnsafeWrite :: PrimMonad m => v (PrimState m) a -> Int -> a -> m ()
118
119  -- | Reset all elements of the vector to some undefined value, clearing all
120  -- references to external objects. This is usually a noop for unboxed
121  -- vectors. This method should not be called directly, use 'clear' instead.
122  basicClear       :: PrimMonad m => v (PrimState m) a -> m ()
123
124  -- | Set all elements of the vector to the given value. This method should
125  -- not be called directly, use 'set' instead.
126  basicSet         :: PrimMonad m => v (PrimState m) a -> a -> m ()
127
128  basicUnsafeCopyPointer :: PrimMonad m => v (PrimState m) a
129                                        -> Immutable v a
130                                        -> m ()
131
132  -- | Copy a vector. The two vectors may not overlap. This method should not
133  -- be called directly, use 'unsafeCopy' instead.
134  basicUnsafeCopy  :: PrimMonad m => v (PrimState m) a   -- ^ target
135                                  -> v (PrimState m) a   -- ^ source
136                                  -> m ()
137
138  -- | Move the contents of a vector. The two vectors may overlap. This method
139  -- should not be called directly, use 'unsafeMove' instead.
140  basicUnsafeMove  :: PrimMonad m => v (PrimState m) a   -- ^ target
141                                  -> v (PrimState m) a   -- ^ source
142                                  -> m ()
143
144  -- | Grow a vector by the given number of elements. This method should not be
145  -- called directly, use 'unsafeGrow' instead.
146  basicUnsafeGrow  :: PrimMonad m => v (PrimState m) a -> Int
147                                                       -> m (v (PrimState m) a)
148
149  {-# INLINE basicUnsafeReplicate #-}
150  basicUnsafeReplicate n x
151    = do
152        v <- basicUnsafeNew n
153        basicSet v x
154        return v
155
156  {-# INLINE basicClear #-}
157  basicClear _ = return ()
158
159  {-# INLINE basicSet #-}
160  basicSet !v x
161    | n == 0    = return ()
162    | otherwise = do
163                    basicUnsafeWrite v 0 x
164                    do_set 1
165    where
166      !n = basicLength v
167
168      do_set i | 2*i < n = do basicUnsafeCopy (basicUnsafeSlice i i v)
169                                              (basicUnsafeSlice 0 i v)
170                              do_set (2*i)
171               | otherwise = basicUnsafeCopy (basicUnsafeSlice i (n-i) v)
172                                             (basicUnsafeSlice 0 (n-i) v)
173
174  {-# INLINE basicUnsafeCopyPointer #-}
175  basicUnsafeCopyPointer !dst !src = do_copy 0 src
176    where
177      do_copy !i p | Just (x,q) <- MBundle.pget p = do
178                                                      basicUnsafeWrite dst i x
179                                                      do_copy (i+1) q
180                   | otherwise = return ()
181
182  {-# INLINE basicUnsafeCopy #-}
183  basicUnsafeCopy !dst !src = do_copy 0
184    where
185      !n = basicLength src
186
187      do_copy i | i < n = do
188                            x <- basicUnsafeRead src i
189                            basicUnsafeWrite dst i x
190                            do_copy (i+1)
191                | otherwise = return ()
192
193  {-# INLINE basicUnsafeMove #-}
194  basicUnsafeMove !dst !src
195    | basicOverlaps dst src = do
196        srcCopy <- clone src
197        basicUnsafeCopy dst srcCopy
198    | otherwise = basicUnsafeCopy dst src
199
200  {-# INLINE basicUnsafeGrow #-}
201  basicUnsafeGrow v by
202    = do
203        v' <- basicUnsafeNew (n+by)
204        basicUnsafeCopy (basicUnsafeSlice 0 n v') v
205        return v'
206    where
207      n = basicLength v
208-}
209
210-- ------------------
211-- Internal functions
212-- ------------------
213
-- Write @x@ at index @i@ of @v@, first growing the vector with 'enlarge'
-- when @i@ is not below its current length. Returns the vector that now
-- holds the element: @v@ itself when it had room, otherwise the enlarged
-- vector, with which the caller must continue.
--
-- Only a single 'enlarge' step is taken, so @i@ must fall inside the
-- enlarged vector; this is verified only when internal checks are enabled.
-- In this module it is called (from 'munstreamUnknown') with @i@ equal to
-- the number of elements written so far.
unsafeAppend1 :: (PrimMonad m, MVector v a)
        => v (PrimState m) a -> Int -> a -> m (v (PrimState m) a)
{-# INLINE_INNER unsafeAppend1 #-}
    -- NOTE: The case distinction has to be on the outside because
    -- GHC creates a join point for the unsafeWrite even when everything
    -- is inlined. This is bad because with the join point, v isn't getting
    -- unboxed.
unsafeAppend1 v i x
  | i < length v = do
                     unsafeWrite v i x
                     return v
  | otherwise    = do
                     v' <- enlarge v
                     INTERNAL_CHECK(checkIndex) "unsafeAppend1" i (length v')
                       $ unsafeWrite v' i x
                     return v'
230
-- Write @x@ immediately before position @i@ of @v@ (that is, at @i-1@).
-- When @i@ is 0 there is no room, so the vector is first grown at the front
-- with 'enlargeFront'. Returns the vector that now holds the element together
-- with the element's index, which is the front position for the next prepend.
--
-- After 'enlargeFront' the old contents start at index @j@ (the number of
-- elements added at the front), so the new element is written at @j-1@.
-- The case split is kept on the outside, mirroring 'unsafeAppend1'.
unsafePrepend1 :: (PrimMonad m, MVector v a)
        => v (PrimState m) a -> Int -> a -> m (v (PrimState m) a, Int)
{-# INLINE_INNER unsafePrepend1 #-}
unsafePrepend1 v i x
  | i /= 0    = do
                  let i' = i-1
                  unsafeWrite v i' x
                  return (v, i')
  | otherwise = do
                  (v', j) <- enlargeFront v
                  let i' = j-1
                  INTERNAL_CHECK(checkIndex) "unsafePrepend1" i' (length v')
                    $ unsafeWrite v' i' x
                  return (v', i')
245
-- | Stream the elements of a mutable vector from left to right. Each element
-- is read with 'unsafeRead' when the corresponding stream step is run.
mstream :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Stream m a
{-# INLINE mstream #-}
mstream vec = vec `seq` len `seq` Stream.unfoldrM step 0
  where
    len = length vec

    {-# INLINE_INNER step #-}
    step ix
      | ix >= len = return Nothing
      | otherwise = do
          el <- unsafeRead vec ix
          return (Just (el, ix + 1))
256
-- | Write the elements of the stream into the mutable vector, starting at
-- index 0, and return the prefix of the vector that was written. The vector
-- must be large enough for every element; this is only verified when
-- internal checks are enabled.
fill :: (PrimMonad m, MVector v a)
     => v (PrimState m) a -> Stream m a -> m (v (PrimState m) a)
{-# INLINE fill #-}
fill v s = v `seq` do
    count <- Stream.foldM writeAt 0 s
    return (unsafeSlice 0 count v)
  where
    {-# INLINE_INNER writeAt #-}
    writeAt ix el = do
      INTERNAL_CHECK(checkIndex) "fill" ix (length v)
        $ unsafeWrite v ix el
      return (ix + 1)
269
-- | Run a stream transformer over the elements of a mutable vector and write
-- the results back into the same vector from index 0 (see 'fill'), returning
-- the prefix that was written.
--
-- NOTE(review): the vector is read (via 'mstream') while it is written (via
-- 'fill'), so this is presumably only safe for transformers that never emit
-- more elements than they have consumed; confirm with callers.
transform
  :: (PrimMonad m, MVector v a)
  => (Stream m a -> Stream m a) -> v (PrimState m) a -> m (v (PrimState m) a)
{-# INLINE_FUSED transform #-}
transform f v = fill v $ f $ mstream v
275
-- | Stream the elements of a mutable vector from right to left, reading
-- each element with 'unsafeRead' when the corresponding step is run.
mstreamR :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Stream m a
{-# INLINE mstreamR #-}
mstreamR vec = vec `seq` len `seq` Stream.unfoldrM step len
  where
    len = length vec

    {-# INLINE_INNER step #-}
    step ix =
      let prev = ix - 1
      in if prev < 0
           then return Nothing
           else do
             el <- unsafeRead vec prev
             return (Just (el, prev))
288
-- | Write the elements of the stream into the mutable vector from right to
-- left, starting at the last index, and return the suffix of the vector that
-- was written. The vector must be large enough for every element; as in
-- 'fill', this is only verified when internal checks are enabled.
fillR :: (PrimMonad m, MVector v a)
      => v (PrimState m) a -> Stream m a -> m (v (PrimState m) a)
{-# INLINE fillR #-}
fillR v s = v `seq` do
                      i <- Stream.foldM put n s
                      return $ unsafeSlice i (n-i) v
  where
    n = length v

    {-# INLINE_INNER put #-}
    -- Each element goes just before the previously written one; the index of
    -- the new element is threaded through the fold.
    put i x = do
                INTERNAL_CHECK(checkIndex) "fillR" j n
                  $ unsafeWrite v j x
                return j
      where
        j = i-1
304
-- | Right-to-left counterpart of 'transform': elements are streamed from the
-- end of the vector (see 'mstreamR') and the results are written back from
-- the end (see 'fillR'), returning the suffix that was written.
transformR
  :: (PrimMonad m, MVector v a)
  => (Stream m a -> Stream m a) -> v (PrimState m) a -> m (v (PrimState m) a)
{-# INLINE_FUSED transformR #-}
transformR f v = fillR v $ f $ mstreamR v
310
-- | Create a new mutable vector and fill it with the elements of the
-- 'Bundle'. When the bundle has no known upper bound on its size, the vector
-- is grown geometrically while it is being filled.
unstream :: (PrimMonad m, MVector v a)
         => Bundle u a -> m (v (PrimState m) a)
-- NOTE: replace INLINE_FUSED by INLINE? (also in unstreamR)
{-# INLINE_FUSED unstream #-}
unstream s = munstream $ Bundle.lift s
319
-- | Create a new mutable vector from a monadic bundle. When the bundle's size
-- has a known upper bound, the vector is allocated once with that length
-- ('munstreamMax'). Otherwise it is grown while filling ('munstreamUnknown').
munstream :: (PrimMonad m, MVector v a)
          => MBundle m u a -> m (v (PrimState m) a)
{-# INLINE_FUSED munstream #-}
munstream s = maybe (munstreamUnknown s) (munstreamMax s)
                    (upperBound (MBundle.size s))
329
-- FIXME: I can't think of how to prevent GHC from floating out
-- munstreamUnknown. That is bad because SpecConstr then generates two
-- specialisations: one for when it is called from unstream (it doesn't know
-- the shape of the vector) and one for when the vector has grown. To see the
-- problem, simply compile this:
--
-- fromList = Data.Vector.Unboxed.unstream . Bundle.fromList
--
-- I'm not sure this still applies (19/04/2010).
339
-- Fill a freshly allocated vector of length @n@ (an upper bound on the
-- bundle's size) from the left, and return the prefix that was actually
-- written.
munstreamMax :: (PrimMonad m, MVector v a)
             => MBundle m u a -> Int -> m (v (PrimState m) a)
{-# INLINE munstreamMax #-}
munstreamMax s n = do
    buf <- INTERNAL_CHECK(checkLength) "munstreamMax" n
             $ unsafeNew n
    let store idx el = do
          INTERNAL_CHECK(checkIndex) "munstreamMax" idx n
            $ unsafeWrite buf idx el
          return (idx + 1)
    filled <- MBundle.foldM' store 0 s
    return $ INTERNAL_CHECK(checkSlice) "munstreamMax" 0 filled n
           $ unsafeSlice 0 filled buf
354
-- Left-to-right unstreaming when the bundle's size is unknown. Start from an
-- empty vector and append each element with 'unsafeAppend1', which enlarges
-- the vector as needed. Returns the prefix that was written.
munstreamUnknown :: (PrimMonad m, MVector v a)
                 => MBundle m u a -> m (v (PrimState m) a)
{-# INLINE munstreamUnknown #-}
munstreamUnknown s = do
    start <- unsafeNew 0
    (final, count) <- MBundle.foldM appendOne (start, 0) s
    return $ INTERNAL_CHECK(checkSlice) "munstreamUnknown" 0 count (length final)
           $ unsafeSlice 0 count final
  where
    {-# INLINE_INNER appendOne #-}
    appendOne (buf, idx) el = do
      buf' <- unsafeAppend1 buf idx el
      return (buf', idx + 1)
369
370
371
372
373
374
375
-- | Create a new mutable vector and fill it with the elements of the
-- 'Bundle', copying whole chunks at a time (see 'vmunstream'). When the
-- bundle has no known upper bound on its size, the vector is grown while it
-- is being filled.
vunstream :: (PrimMonad m, V.Vector v a)
         => Bundle v a -> m (V.Mutable v (PrimState m) a)
-- NOTE: replace INLINE_FUSED by INLINE? (also in unstream and unstreamR)
{-# INLINE_FUSED vunstream #-}
vunstream s = vmunstream $ Bundle.lift s
384
-- | Create a new mutable vector from a monadic bundle, chunk by chunk. When
-- the bundle's size has a known upper bound, the vector is allocated once
-- ('vmunstreamMax'). Otherwise it is grown while filling
-- ('vmunstreamUnknown').
vmunstream :: (PrimMonad m, V.Vector v a)
           => MBundle m v a -> m (V.Mutable v (PrimState m) a)
{-# INLINE_FUSED vmunstream #-}
vmunstream s = maybe (vmunstreamUnknown s) (vmunstreamMax s)
                     (upperBound (MBundle.size s))
394
-- FIXME: (the same concern as the note on munstreamUnknown, applied to the
-- chunked variant) I can't think of how to prevent GHC from floating out
-- vmunstreamUnknown. That is bad because SpecConstr then generates two
-- specialisations: one for when it is called from vunstream (it doesn't know
-- the shape of the vector) and one for when the vector has grown. To see the
-- problem, simply compile this:
--
-- fromList = Data.Vector.Unboxed.unstream . Bundle.fromList
--
-- I'm not sure this still applies (19/04/2010).
404
-- Chunk-wise unstreaming when the bundle's size has a known upper bound @n@.
-- Allocate a vector of length @n@ once, and let each chunk's writer fill the
-- slice it occupies directly. Returns the prefix that was written.
vmunstreamMax :: (PrimMonad m, V.Vector v a)
              => MBundle m v a -> Int -> m (V.Mutable v (PrimState m) a)
{-# INLINE vmunstreamMax #-}
vmunstreamMax s n
  = do
      v <- INTERNAL_CHECK(checkLength) "vmunstreamMax" n
           $ unsafeNew n
      -- Each chunk of length m is written into v[i .. i+m-1]; the next chunk
      -- starts right after it.
      let {-# INLINE_INNER copyChunk #-}
          copyChunk i (Chunk m f) =
            INTERNAL_CHECK(checkSlice) "vmunstreamMax.copyChunk" i m (length v) $ do
              f (basicUnsafeSlice i m v)
              return (i+m)

      n' <- Stream.foldlM' copyChunk 0 (MBundle.chunks s)
      return $ INTERNAL_CHECK(checkSlice) "vmunstreamMax" 0 n' n
             $ unsafeSlice 0 n' v
421
-- Chunk-wise unstreaming when the bundle's size is unknown. Start from an
-- empty vector. Whenever the next chunk does not fit, grow the vector by the
-- larger of 'enlarge_delta' (doubling it) and the exact shortfall, then run
-- the chunk's writer on the slice it occupies. Returns the prefix that was
-- written.
vmunstreamUnknown :: (PrimMonad m, V.Vector v a)
                 => MBundle m v a -> m (V.Mutable v (PrimState m) a)
{-# INLINE vmunstreamUnknown #-}
vmunstreamUnknown s
  = do
      v <- unsafeNew 0
      (v', n) <- Stream.foldlM copyChunk (v,0) (MBundle.chunks s)
      return $ INTERNAL_CHECK(checkSlice) "vmunstreamUnknown" 0 n (length v')
             $ unsafeSlice 0 n v'
  where
    {-# INLINE_INNER copyChunk #-}
    copyChunk (v,i) (Chunk n f)
      = do
          let j = i+n
          -- Grow only when this chunk would run past the current end.
          v' <- if basicLength v < j
                  then unsafeGrow v (delay_inline max (enlarge_delta v) (j - basicLength v))
                  else return v
          INTERNAL_CHECK(checkSlice) "vmunstreamUnknown.copyChunk" i n (length v')
            $ f (basicUnsafeSlice i n v')
          return (v',j)
442
443
444
445
-- | Create a new mutable vector and fill it with the elements of the
-- 'Bundle' from right to left. When the bundle has no known upper bound on
-- its size, the vector is grown geometrically (at the front) while it is
-- being filled.
unstreamR :: (PrimMonad m, MVector v a)
          => Bundle u a -> m (v (PrimState m) a)
-- NOTE: replace INLINE_FUSED by INLINE? (also in unstream)
{-# INLINE_FUSED unstreamR #-}
unstreamR s = munstreamR $ Bundle.lift s
454
-- | Create a new mutable vector from a monadic bundle, filling it from right
-- to left. When the bundle's size has a known upper bound, the vector is
-- allocated once ('munstreamRMax'). Otherwise it is grown at the front while
-- filling ('munstreamRUnknown').
munstreamR :: (PrimMonad m, MVector v a)
           => MBundle m u a -> m (v (PrimState m) a)
{-# INLINE_FUSED munstreamR #-}
munstreamR s = maybe (munstreamRUnknown s) (munstreamRMax s)
                     (upperBound (MBundle.size s))
464
-- Right-to-left counterpart of 'munstreamMax'. Allocate @n@ elements, write
-- the bundle's elements backwards starting from the last index, and return
-- the suffix that was written.
munstreamRMax :: (PrimMonad m, MVector v a)
              => MBundle m u a -> Int -> m (v (PrimState m) a)
{-# INLINE munstreamRMax #-}
munstreamRMax s n = do
    buf <- INTERNAL_CHECK(checkLength) "munstreamRMax" n
             $ unsafeNew n
    let store idx el = do
          let prev = idx - 1
          INTERNAL_CHECK(checkIndex) "munstreamRMax" prev n
            $ unsafeWrite buf prev el
          return prev
    start <- MBundle.foldM' store n s
    return $ INTERNAL_CHECK(checkSlice) "munstreamRMax" start (n - start) n
           $ unsafeSlice start (n - start) buf
480
-- Right-to-left unstreaming when the bundle's size is unknown. Start from an
-- empty vector and prepend each element with 'unsafePrepend1', which grows
-- the vector at the front as needed. Returns the suffix that was written.
munstreamRUnknown :: (PrimMonad m, MVector v a)
                  => MBundle m u a -> m (v (PrimState m) a)
{-# INLINE munstreamRUnknown #-}
munstreamRUnknown s
  = do
      v <- unsafeNew 0
      (v', i) <- MBundle.foldM put (v, 0) s
      let n = length v'
      return $ INTERNAL_CHECK(checkSlice) "munstreamRUnknown" i (n-i) n
             $ unsafeSlice i (n-i) v'
  where
    {-# INLINE_INNER put #-}
    put (v,i) x = unsafePrepend1 v i x
494
495-- Length
496-- ------
497
-- | Yield the number of elements in the mutable vector.
length :: MVector v a => v s a -> Int
{-# INLINE length #-}
length v = basicLength v
502
-- | Test whether the mutable vector has no elements.
null :: MVector v a => v s a -> Bool
{-# INLINE null #-}
null v = 0 == length v
507
508-- Extracting subvectors
509-- ---------------------
510
-- | Yield @n@ elements of the mutable vector starting at index @i@, without
-- copying. The slice must lie within the vector, i.e. the vector must
-- contain at least @i+n@ elements; this is bounds-checked.
slice :: MVector v a
      => Int  -- ^ @i@ starting index
      -> Int  -- ^ @n@ length
      -> v s a
      -> v s a
{-# INLINE slice #-}
slice i n v
  = BOUNDS_CHECK(checkSlice) "slice" i n (length v)
  $ unsafeSlice i n v
521
-- | Yield the first @n@ elements without copying. A negative @n@ yields an
-- empty vector, and an @n@ larger than the length yields the whole vector.
take :: MVector v a => Int -> v s a -> v s a
{-# INLINE take #-}
take n v = unsafeSlice 0 cnt v
  where
    cnt = min (length v) (max 0 n)
525
-- | Yield all but the first @n@ elements without copying. A negative @n@
-- yields the whole vector, and an @n@ larger than the length yields an empty
-- vector.
drop :: MVector v a => Int -> v s a -> v s a
{-# INLINE drop #-}
drop n v = unsafeSlice start (len - start) v
  where
    len   = length v
    start = min len (max 0 n)
532
{-# INLINE splitAt #-}
-- | Split the vector after the first @n@ elements, without copying. This is
-- equivalent to @('take' n v, 'drop' n v)@, with @n@ clamped to the range
-- @[0, length v]@.
splitAt :: MVector v a => Int -> v s a -> (v s a, v s a)
splitAt n v = (unsafeSlice 0 cut v, unsafeSlice cut (len - cut) v)
  where
    len = length v
    cut = min len (max 0 n)
542
-- | Yield all but the last element without copying. The vector must not be
-- empty; this is bounds-checked through 'slice'.
init :: MVector v a => v s a -> v s a
{-# INLINE init #-}
init v = slice 0 lastIx v
  where
    lastIx = length v - 1
546
-- | Yield all but the first element without copying. The vector must not be
-- empty; this is bounds-checked through 'slice'.
tail :: MVector v a => v s a -> v s a
{-# INLINE tail #-}
tail v = slice 1 rest v
  where
    rest = length v - 1
550
-- | Yield @n@ elements starting at index @i@, without copying. No bounds
-- checks are performed unless unsafe checks are enabled.
unsafeSlice :: MVector v a => Int  -- ^ starting index
                           -> Int  -- ^ length of the slice
                           -> v s a
                           -> v s a
{-# INLINE unsafeSlice #-}
unsafeSlice i n v
  = UNSAFE_CHECK(checkSlice) "unsafeSlice" i n (length v)
  $ basicUnsafeSlice i n v
560
-- | Yield all but the last element without copying. The vector must not be
-- empty; this is checked only when unsafe checks are enabled.
unsafeInit :: MVector v a => v s a -> v s a
{-# INLINE unsafeInit #-}
unsafeInit v = unsafeSlice 0 lastIx v
  where
    lastIx = length v - 1
564
-- | Yield all but the first element without copying. The vector must not be
-- empty; this is checked only when unsafe checks are enabled.
unsafeTail :: MVector v a => v s a -> v s a
{-# INLINE unsafeTail #-}
unsafeTail v = unsafeSlice 1 rest v
  where
    rest = length v - 1
568
-- | Yield the first @n@ elements without copying. @n@ must be between 0 and
-- the length of the vector; this is checked only when unsafe checks are
-- enabled.
unsafeTake :: MVector v a => Int -> v s a -> v s a
{-# INLINE unsafeTake #-}
unsafeTake k vec = unsafeSlice 0 k vec
572
-- | Yield all but the first @n@ elements without copying. @n@ must be
-- between 0 and the length of the vector; this is checked only when unsafe
-- checks are enabled.
unsafeDrop :: MVector v a => Int -> v s a -> v s a
{-# INLINE unsafeDrop #-}
unsafeDrop n v = unsafeSlice n rest v
  where
    rest = length v - n
576
577-- Overlapping
578-- -----------
579
-- | Check whether two mutable vectors overlap, as decided by the instance's
-- 'basicOverlaps'.
overlaps :: MVector v a => v s a -> v s a -> Bool
{-# INLINE overlaps #-}
overlaps a b = basicOverlaps a b
584
585-- Initialisation
586-- --------------
587
-- | Create a mutable vector of the given length. The length is
-- bounds-checked, and the fresh vector is passed to 'basicInitialize'
-- before it is returned.
new :: (PrimMonad m, MVector v a) => Int -> m (v (PrimState m) a)
{-# INLINE new #-}
new n = BOUNDS_CHECK(checkLength) "new" n
      $ do
          fresh <- unsafeNew n
          basicInitialize fresh
          return fresh
593
-- | Create a mutable vector of the given length without initialising its
-- memory. The length is checked only when unsafe checks are enabled.
unsafeNew :: (PrimMonad m, MVector v a) => Int -> m (v (PrimState m) a)
{-# INLINE unsafeNew #-}
unsafeNew n
  = UNSAFE_CHECK(checkLength) "unsafeNew" n
  $ basicUnsafeNew n
599
-- | Create a mutable vector of the given length, with every element set to
-- the given value. A negative length is treated as 0.
replicate :: (PrimMonad m, MVector v a) => Int -> a -> m (v (PrimState m) a)
{-# INLINE replicate #-}
replicate n x = basicUnsafeReplicate len x
  where
    len = delay_inline max 0 n
605
-- | Create a mutable vector of the given length (0 if the length is
-- negative) whose elements are produced by running the monadic action once
-- per element.
replicateM :: (PrimMonad m, MVector v a) => Int -> m a -> m (v (PrimState m) a)
{-# INLINE replicateM #-}
replicateM n action = munstream $ MBundle.replicateM n action
611
-- | Create a copy of a mutable vector: allocate a new vector of the same
-- length and copy the contents into it.
clone :: (PrimMonad m, MVector v a) => v (PrimState m) a -> m (v (PrimState m) a)
{-# INLINE clone #-}
clone src = do
  dst <- unsafeNew (length src)
  unsafeCopy dst src
  return dst
619
620-- Growing
621-- -------
622
-- | Grow a vector by the given number of elements. The number must be
-- non-negative (growing by 0 is allowed); this is bounds-checked. Returns
-- the grown vector. The @by@ new elements at its end are passed to
-- 'basicInitialize'.
grow :: (PrimMonad m, MVector v a)
                => v (PrimState m) a -> Int -> m (v (PrimState m) a)
{-# INLINE grow #-}
grow v by = BOUNDS_CHECK(checkLength) "grow" by
          $ do vnew <- unsafeGrow v by
               basicInitialize $ basicUnsafeSlice (length v) by vnew
               return vnew
632
-- | Grow a vector at the front by the given number of elements. The number
-- must be non-negative; this is bounds-checked. The old contents end up at
-- the back of the returned vector, and the @by@ new elements at its front
-- are passed to 'basicInitialize'.
growFront :: (PrimMonad m, MVector v a)
                => v (PrimState m) a -> Int -> m (v (PrimState m) a)
{-# INLINE growFront #-}
growFront v by = BOUNDS_CHECK(checkLength) "growFront" by
               $ do
                   grown <- unsafeGrowFront v by
                   basicInitialize (basicUnsafeSlice 0 by grown)
                   return grown
640
-- Number of elements to add when enlarging a vector: its current length
-- (so the vector doubles), but at least 1 so an empty vector can grow.
enlarge_delta :: MVector v a => v s a -> Int
enlarge_delta v = max 1 (length v)
643
-- | Grow a vector geometrically. It is extended by 'enlarge_delta' elements,
-- which doubles its length (or gives length 1 when it is empty). The new
-- elements at the end are passed to 'basicInitialize'. Because the length
-- doubles, a sequence of appends needs only a logarithmic number of
-- reallocations.
enlarge :: (PrimMonad m, MVector v a)
                => v (PrimState m) a -> m (v (PrimState m) a)
{-# INLINE enlarge #-}
enlarge v = do vnew <- unsafeGrow v by
               basicInitialize $ basicUnsafeSlice (length v) by vnew
               return vnew
  where
    by = enlarge_delta v
653
-- Grow a vector at the front by 'enlarge_delta' elements (doubling it, or
-- giving length 1 when it is empty) and pass the new front elements to
-- 'basicInitialize'. Returns the grown vector together with the number of
-- elements added, which is also the index where the old contents now start.
enlargeFront :: (PrimMonad m, MVector v a)
                => v (PrimState m) a -> m (v (PrimState m) a, Int)
{-# INLINE enlargeFront #-}
enlargeFront v = do
    let by = enlarge_delta v
    grown <- unsafeGrowFront v by
    basicInitialize (basicUnsafeSlice 0 by grown)
    return (grown, by)
663
-- | Grow a vector by the given number of elements. The number must be
-- non-negative; this is checked only when unsafe checks are enabled. Unlike
-- 'grow', this does not call 'basicInitialize' on the new elements.
unsafeGrow :: (PrimMonad m, MVector v a)
                        => v (PrimState m) a -> Int -> m (v (PrimState m) a)
{-# INLINE unsafeGrow #-}
unsafeGrow v n = UNSAFE_CHECK(checkLength) "unsafeGrow" n
               $ basicUnsafeGrow v n
671
-- | Grow a vector at the front by the given number of elements. A vector of
-- length @by + length v@ is allocated with 'basicUnsafeNew', and the old
-- contents are copied into its last @length v@ positions. The @by@ front
-- elements are not initialised here. The number must be non-negative; this
-- is checked only when unsafe checks are enabled.
unsafeGrowFront :: (PrimMonad m, MVector v a)
                        => v (PrimState m) a -> Int -> m (v (PrimState m) a)
{-# INLINE unsafeGrowFront #-}
unsafeGrowFront v by = UNSAFE_CHECK(checkLength) "unsafeGrowFront" by
                     $ do
                         grown <- basicUnsafeNew (by + len)
                         basicUnsafeCopy (basicUnsafeSlice by len grown) v
                         return grown
  where
    len = length v
681
682-- Restricting memory usage
683-- ------------------------
684
-- | Reset every element of the vector to some undefined value, dropping all
-- references to external objects. For unboxed vectors this is usually a
-- no-op.
clear :: (PrimMonad m, MVector v a) => v (PrimState m) a -> m ()
{-# INLINE clear #-}
clear v = basicClear v
690
691-- Accessing individual elements
692-- -----------------------------
693
-- | Yield the element at the given position. The index is bounds-checked.
read :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> m a
{-# INLINE read #-}
read v i
  = BOUNDS_CHECK(checkIndex) "read" i (length v)
  $ unsafeRead v i
699
-- | Replace the element at the given position. The index is bounds-checked.
write :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> a -> m ()
{-# INLINE write #-}
write v i x
  = BOUNDS_CHECK(checkIndex) "write" i (length v)
  $ unsafeWrite v i x
705
-- | Apply a function to the element at the given position, in place. The
-- index is bounds-checked.
modify :: (PrimMonad m, MVector v a) => v (PrimState m) a -> (a -> a) -> Int -> m ()
{-# INLINE modify #-}
modify v f i
  = BOUNDS_CHECK(checkIndex) "modify" i (length v)
  $ unsafeModify v f i
711
-- | Swap the elements at the two given positions. Both indices are
-- bounds-checked, @i@ first.
swap :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> Int -> m ()
{-# INLINE swap #-}
swap v i j = BOUNDS_CHECK(checkIndex) "swap" i len
           $ BOUNDS_CHECK(checkIndex) "swap" j len
           $ unsafeSwap v i j
  where
    len = length v
718
-- | Store a new value at the given position and return the value that was
-- there before. The index is bounds-checked.
exchange :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> a -> m a
{-# INLINE exchange #-}
exchange v i x
  = BOUNDS_CHECK(checkIndex) "exchange" i (length v)
  $ unsafeExchange v i x
724
-- | Yield the element at the given position. The index is checked only when
-- unsafe checks are enabled.
unsafeRead :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> m a
{-# INLINE unsafeRead #-}
unsafeRead v i
  = UNSAFE_CHECK(checkIndex) "unsafeRead" i (length v)
  $ basicUnsafeRead v i
730
-- | Replace the element at the given position. The index is checked only
-- when unsafe checks are enabled.
unsafeWrite :: (PrimMonad m, MVector v a)
                                => v (PrimState m) a -> Int -> a -> m ()
{-# INLINE unsafeWrite #-}
unsafeWrite v i x
  = UNSAFE_CHECK(checkIndex) "unsafeWrite" i (length v)
  $ basicUnsafeWrite v i x
737
-- | Apply a function to the element at the given position, in place. The
-- index is checked only when unsafe checks are enabled.
unsafeModify :: (PrimMonad m, MVector v a) => v (PrimState m) a -> (a -> a) -> Int -> m ()
{-# INLINE unsafeModify #-}
unsafeModify v f i = UNSAFE_CHECK(checkIndex) "unsafeModify" i (length v)
                   $ do
                       old <- basicUnsafeRead v i
                       basicUnsafeWrite v i (f old)
744
-- | Swap the elements at the two given positions. The indices are checked
-- only when unsafe checks are enabled.
unsafeSwap :: (PrimMonad m, MVector v a)
                => v (PrimState m) a -> Int -> Int -> m ()
{-# INLINE unsafeSwap #-}
unsafeSwap v i j
  = UNSAFE_CHECK(checkIndex) "unsafeSwap" i (length v)
  $ UNSAFE_CHECK(checkIndex) "unsafeSwap" j (length v)
  $ do
      atI <- unsafeRead v i
      atJ <- unsafeRead v j
      unsafeWrite v i atJ
      unsafeWrite v j atI
756
-- | Store a new value at the given position and return the value that was
-- there before. The index is checked only when unsafe checks are enabled.
unsafeExchange :: (PrimMonad m, MVector v a)
                                => v (PrimState m) a -> Int -> a -> m a
{-# INLINE unsafeExchange #-}
unsafeExchange v i x
  = UNSAFE_CHECK(checkIndex) "unsafeExchange" i (length v)
  $ do
      previous <- unsafeRead v i
      unsafeWrite v i x
      return previous
767
768-- Filling and copying
769-- -------------------
770
-- | Overwrite every element of the vector with the given value.
set :: (PrimMonad m, MVector v a) => v (PrimState m) a -> a -> m ()
{-# INLINE set #-}
set v x = basicSet v x
775
-- | Copy the source vector into the target vector. The vectors must not
-- overlap and must have the same length. Both conditions are bounds-checked,
-- overlap first.
copy :: (PrimMonad m, MVector v a) => v (PrimState m) a   -- ^ target
                                   -> v (PrimState m) a   -- ^ source
                                   -> m ()
{-# INLINE copy #-}
copy dst src
  = BOUNDS_CHECK(check) "copy" "overlapping vectors" (not (overlaps dst src))
  $ BOUNDS_CHECK(check) "copy" "length mismatch" (length dst == length src)
  $ unsafeCopy dst src
787
-- | Move the contents of the source vector into the target vector. The
-- vectors must have the same length; this is bounds-checked.
--
-- If the vectors do not overlap, this is equivalent to 'copy'. Otherwise
-- the result is as if the source were first copied to a temporary vector
-- and the temporary vector then copied to the target.
move :: (PrimMonad m, MVector v a)
     => v (PrimState m) a   -- ^ target
     -> v (PrimState m) a   -- ^ source
     -> m ()
{-# INLINE move #-}
move dst src
  = BOUNDS_CHECK(check) "move" "length mismatch" (length dst == length src)
  $ unsafeMove dst src
803
-- | Copy the source vector into the target vector. The vectors must have
-- the same length and must not overlap. Both conditions are checked only
-- when unsafe checks are enabled, length first. Both vectors are forced
-- before 'basicUnsafeCopy' runs.
unsafeCopy :: (PrimMonad m, MVector v a) => v (PrimState m) a   -- ^ target
                                         -> v (PrimState m) a   -- ^ source
                                         -> m ()
{-# INLINE unsafeCopy #-}
unsafeCopy dst src
  = UNSAFE_CHECK(check) "unsafeCopy" "length mismatch"
                        (length dst == length src)
  $ UNSAFE_CHECK(check) "unsafeCopy" "overlapping vectors"
                        (not (overlaps dst src))
  $ (dst `seq` src `seq` basicUnsafeCopy dst src)
815
-- | Move the contents of the source vector into the target vector. The
-- vectors must have the same length; this is checked only when unsafe checks
-- are enabled.
--
-- If the vectors do not overlap, this is equivalent to 'unsafeCopy'.
-- Otherwise the result is as if the source were first copied to a temporary
-- vector and the temporary vector then copied to the target.
unsafeMove :: (PrimMonad m, MVector v a) => v (PrimState m) a   -- ^ target
                                         -> v (PrimState m) a   -- ^ source
                                         -> m ()
{-# INLINE unsafeMove #-}
unsafeMove dst src
  = UNSAFE_CHECK(check) "unsafeMove" "length mismatch"
                        (length dst == length src)
  $ (dst `seq` src `seq` basicUnsafeMove dst src)
830
-- Bulk updates and permutations
-- -----------------------------
833
-- | Combine each value of the bundle with the element at its paired index.
--
-- For every pair @(i, b)@, element @i@ of the vector is replaced by
-- @f old b@, where @old@ is its current value. Each index is
-- bounds-checked against the length of the vector.
accum :: (PrimMonad m, MVector v a)
      => (a -> b -> a) -> v (PrimState m) a -> Bundle u (Int, b) -> m ()
{-# INLINE accum #-}
accum f !v s = Bundle.mapM_ step s
  where
    !len = length v

    {-# INLINE_INNER step #-}
    step (idx, x) = do
      old <- BOUNDS_CHECK(checkIndex) "accum" idx len
           $ unsafeRead v idx
      unsafeWrite v idx (f old x)
846
-- | Write each value of the bundle at its paired index, processing the
-- pairs in bundle order. Each index is bounds-checked against the length
-- of the vector.
update :: (PrimMonad m, MVector v a)
       => v (PrimState m) a -> Bundle u (Int, a) -> m ()
{-# INLINE update #-}
update !v s = Bundle.mapM_ writeAt s
  where
    !len = length v

    {-# INLINE_INNER writeAt #-}
    writeAt (idx, x)
      = BOUNDS_CHECK(checkIndex) "update" idx len
      $ unsafeWrite v idx x
857
-- | Like 'accum', but indices are only checked by the UNSAFE_CHECK
-- assertion, not by a bounds check.
--
-- For every pair @(i, b)@ of the bundle, element @i@ of the vector is
-- replaced by @f a b@, where @a@ is its current value.
unsafeAccum :: (PrimMonad m, MVector v a)
            => (a -> b -> a) -> v (PrimState m) a -> Bundle u (Int, b) -> m ()
{-# INLINE unsafeAccum #-}
unsafeAccum f !v s = Bundle.mapM_ upd s
  where
    {-# INLINE_INNER upd #-}
    upd (i,b) = do
                  -- Name this function, not 'accum', in the diagnostic.
                  a <- UNSAFE_CHECK(checkIndex) "unsafeAccum" i n
                     $ unsafeRead v i
                  unsafeWrite v i (f a b)

    !n = length v
870
-- | Like 'update', but indices are only checked by the UNSAFE_CHECK
-- assertion, not by a bounds check.
--
-- For every pair @(i, b)@ of the bundle, @b@ is written at index @i@.
unsafeUpdate :: (PrimMonad m, MVector v a)
                        => v (PrimState m) a -> Bundle u (Int, a) -> m ()
{-# INLINE unsafeUpdate #-}
unsafeUpdate !v s = Bundle.mapM_ upd s
  where
    {-# INLINE_INNER upd #-}
    -- The diagnostic previously said "accum" (copy-paste); name this function.
    upd (i,b) = UNSAFE_CHECK(checkIndex) "unsafeUpdate" i n
                  $ unsafeWrite v i b

    !n = length v
881
-- | Reverse the vector in place by swapping pairs of elements, working
-- inwards from both ends until the two indices meet.
reverse :: (PrimMonad m, MVector v a) => v (PrimState m) a -> m ()
{-# INLINE reverse #-}
reverse !v = go 0 (length v - 1)
  where
    go lo hi
      | lo >= hi  = return ()
      | otherwise = unsafeSwap v lo hi >> go (lo + 1) (hi - 1)
890
-- | Partition the vector in place so that all elements satisfying the
-- predicate come before all elements that do not. Returns the number of
-- elements satisfying the predicate, which is also the index of the first
-- element that does not. The relative order of elements is not preserved.
unstablePartition :: forall m v a. (PrimMonad m, MVector v a)
                  => (a -> Bool) -> v (PrimState m) a -> m Int
{-# INLINE unstablePartition #-}
unstablePartition f !v = scanLeft 0 (length v)
  where
    -- NOTE: GHC 6.10.4 panics without the signatures on scanLeft and
    -- scanRight.

    -- Advance lo past elements that already satisfy f; hi is exclusive.
    scanLeft :: Int -> Int -> m Int
    scanLeft lo hi
      | lo == hi  = return lo
      | otherwise = do
          x <- unsafeRead v lo
          if f x then scanLeft (lo + 1) hi
                 else scanRight lo (hi - 1)

    -- Retreat hi past elements failing f; on a hit, swap it with slot lo
    -- (which is known to fail f) and resume scanning from the left.
    scanRight :: Int -> Int -> m Int
    scanRight lo hi
      | lo == hi  = return lo
      | otherwise = do
          y <- unsafeRead v hi
          if not (f y)
            then scanRight lo (hi - 1)
            else do
              z <- unsafeRead v lo
              unsafeWrite v lo y
              unsafeWrite v hi z
              scanLeft (lo + 1) hi
919
-- | Split a bundle into the elements that satisfy the predicate and those
-- that do not. The relative order of elements is not guaranteed.
--
-- When the bundle size has an upper bound, the bundle is processed by
-- 'unstablePartitionMax'. Otherwise 'partitionUnknown' is used.
unstablePartitionBundle :: (PrimMonad m, MVector v a)
        => (a -> Bool) -> Bundle u a -> m (v (PrimState m) a, v (PrimState m) a)
{-# INLINE unstablePartitionBundle #-}
unstablePartitionBundle f s
  = maybe (partitionUnknown f s) (unstablePartitionMax f s)
  $ upperBound (Bundle.size s)
927
-- | Partition a bundle of at most @n@ elements using one buffer of length
-- @n@.
--
-- Elements satisfying the predicate fill the buffer from the front. The
-- others fill it from the back, so they end up in reverse order. Returns
-- the front slice and the back slice.
unstablePartitionMax :: (PrimMonad m, MVector v a)
        => (a -> Bool) -> Bundle u a -> Int
        -> m (v (PrimState m) a, v (PrimState m) a)
{-# INLINE unstablePartitionMax #-}
unstablePartitionMax f s n
  = do
      buf <- INTERNAL_CHECK(checkLength) "unstablePartitionMax" n
             $ unsafeNew n
      let {-# INLINE_INNER place #-}
          place (front, back) x
            | f x       = unsafeWrite buf front x >> return (front + 1, back)
            | otherwise = unsafeWrite buf (back - 1) x >> return (front, back - 1)

      (front, back) <- Bundle.foldM' place (0, n) s
      return (unsafeSlice 0 front buf, unsafeSlice back (n - back) buf)
947
-- | Split a bundle into the elements that satisfy the predicate and those
-- that do not.
--
-- When the bundle size has an upper bound, the bundle is processed by
-- 'partitionMax'. Otherwise 'partitionUnknown' is used.
partitionBundle :: (PrimMonad m, MVector v a)
        => (a -> Bool) -> Bundle u a -> m (v (PrimState m) a, v (PrimState m) a)
{-# INLINE partitionBundle #-}
partitionBundle f s
  = maybe (partitionUnknown f s) (partitionMax f s)
  $ upperBound (Bundle.size s)
955
-- | Partition a bundle of at most @n@ elements using one buffer of length
-- @n@, keeping the original relative order in both halves.
--
-- Elements satisfying the predicate are written from the front. The others
-- are written from the back, so they land in reverse order. That back
-- slice is therefore reversed in place before it is returned.
partitionMax :: (PrimMonad m, MVector v a)
  => (a -> Bool) -> Bundle u a -> Int -> m (v (PrimState m) a, v (PrimState m) a)
{-# INLINE partitionMax #-}
partitionMax f s n
  = do
      -- The diagnostic previously named "unstablePartitionMax" (copy-paste).
      v <- INTERNAL_CHECK(checkLength) "partitionMax" n
         $ unsafeNew n

      let {-# INLINE_INNER put #-}
          put (i,j) x
            | f x       = do
                            unsafeWrite v i x
                            return (i+1,j)

            | otherwise = let j' = j-1 in
                          do
                            unsafeWrite v j' x
                            return (i,j')

      (i,j) <- Bundle.foldM' put (0,n) s
      INTERNAL_CHECK(check) "partitionMax" "invalid indices" (i <= j)
        $ return ()
      let l = unsafeSlice 0 i v
          r = unsafeSlice j (n-j) v
      -- Restore the original order of the elements written back-to-front.
      reverse r
      return (l,r)
982
-- | Partition a bundle whose size has no known upper bound.
--
-- Starts from two empty vectors. Each element is appended with
-- 'unsafeAppend1' to the first vector if it satisfies the predicate and to
-- the second otherwise. Returns the filled prefix of each vector.
partitionUnknown :: (PrimMonad m, MVector v a)
        => (a -> Bool) -> Bundle u a -> m (v (PrimState m) a, v (PrimState m) a)
{-# INLINE partitionUnknown #-}
partitionUnknown f s
  = do
      emptyYes <- unsafeNew 0
      emptyNo  <- unsafeNew 0
      (yes, nYes, no, nNo) <- Bundle.foldM' put (emptyYes, 0, emptyNo, 0) s
      INTERNAL_CHECK(checkSlice) "partitionUnknown" 0 nYes (length yes)
        $ INTERNAL_CHECK(checkSlice) "partitionUnknown" 0 nNo (length no)
        $ return (unsafeSlice 0 nYes yes, unsafeSlice 0 nNo no)
  where
    -- NOTE: The case distinction has to be on the outside because
    -- GHC creates a join point for the unsafeWrite even when everything
    -- is inlined. This is bad because with the join point, v isn't getting
    -- unboxed.
    {-# INLINE_INNER put #-}
    put (yes, nYes, no, nNo) x
      | f x       = do
          yes' <- unsafeAppend1 yes nYes x
          return (yes', nYes + 1, no, nNo)
      | otherwise = do
          no' <- unsafeAppend1 no nNo x
          return (yes, nYes, no', nNo + 1)
1007
1008
-- | Split a bundle by a function returning 'Either'. 'Left' results go to
-- the first vector and 'Right' results to the second.
--
-- When the bundle size has an upper bound, the bundle is processed by
-- 'partitionWithMax'. Otherwise 'partitionWithUnknown' is used.
partitionWithBundle :: (PrimMonad m, MVector v a, MVector v b, MVector v c)
        => (a -> Either b c) -> Bundle u a -> m (v (PrimState m) b, v (PrimState m) c)
{-# INLINE partitionWithBundle #-}
partitionWithBundle f s
  = maybe (partitionWithUnknown f s) (partitionWithMax f s)
  $ upperBound (Bundle.size s)
1016
-- | Split a bundle of at most @n@ elements by a function returning
-- 'Either'.
--
-- Two buffers of length @n@ are allocated. 'Left' results are written to
-- the first and 'Right' results to the second, both in bundle order. The
-- filled prefix of each buffer is returned.
partitionWithMax :: (PrimMonad m, MVector v a, MVector v b, MVector v c)
  => (a -> Either b c) -> Bundle u a -> Int -> m (v (PrimState m) b, v (PrimState m) c)
{-# INLINE partitionWithMax #-}
partitionWithMax f s n
  = do
      -- Validate the allocation length, as partitionMax does.
      v1 <- INTERNAL_CHECK(checkLength) "partitionWithMax" n
            $ unsafeNew n
      v2 <- INTERNAL_CHECK(checkLength) "partitionWithMax" n
            $ unsafeNew n
      let {-# INLINE_INNER put #-}
          put (i1, i2) x = case f x of
            Left b -> do
              unsafeWrite v1 i1 b
              return (i1+1, i2)
            Right c -> do
              unsafeWrite v2 i2 c
              return (i1, i2+1)
      (n1, n2) <- Bundle.foldM' put (0, 0) s
      -- Diagnostics previously used the stale name "partitionEithersMax".
      INTERNAL_CHECK(checkSlice) "partitionWithMax" 0 n1 (length v1)
        $ INTERNAL_CHECK(checkSlice) "partitionWithMax" 0 n2 (length v2)
        $ return (unsafeSlice 0 n1 v1, unsafeSlice 0 n2 v2)
1036
-- | Split a bundle with no known size bound by a function returning
-- 'Either'.
--
-- Starts from two empty vectors. 'Left' results are appended with
-- 'unsafeAppend1' to the first and 'Right' results to the second. The
-- filled prefix of each vector is returned.
partitionWithUnknown :: forall m v u a b c.
     (PrimMonad m, MVector v a, MVector v b, MVector v c)
  => (a -> Either b c) -> Bundle u a -> m (v (PrimState m) b, v (PrimState m) c)
{-# INLINE partitionWithUnknown #-}
partitionWithUnknown f s
  = do
      v1 <- unsafeNew 0
      v2 <- unsafeNew 0
      (v1', n1, v2', n2) <- Bundle.foldM' put (v1, 0, v2, 0) s
      -- Diagnostics previously used the stale name "partitionEithersUnknown".
      INTERNAL_CHECK(checkSlice) "partitionWithUnknown" 0 n1 (length v1')
        $ INTERNAL_CHECK(checkSlice) "partitionWithUnknown" 0 n2 (length v2')
        $ return (unsafeSlice 0 n1 v1', unsafeSlice 0 n2 v2')
  where
    put :: (v (PrimState m) b, Int, v (PrimState m) c, Int)
        -> a
        -> m (v (PrimState m) b, Int, v (PrimState m) c, Int)
    {-# INLINE_INNER put #-}
    put (v1, i1, v2, i2) x = case f x of
      Left b -> do
        v1' <- unsafeAppend1 v1 i1 b
        return (v1', i1+1, v2, i2)
      Right c -> do
        v2' <- unsafeAppend1 v2 i2 c
        return (v1, i1, v2', i2+1)
1061
1062{-
1063http://en.wikipedia.org/wiki/Permutation#Algorithms_to_generate_permutations
1064
1065The following algorithm generates the next permutation lexicographically after
1066a given permutation. It changes the given permutation in-place.
1067
10681. Find the largest index k such that a[k] < a[k + 1]. If no such index exists,
1069   the permutation is the last permutation.
10702. Find the largest index l greater than k such that a[k] < a[l].
10713. Swap the value of a[k] with that of a[l].
4. Reverse the sequence from a[k + 1] up to and including the final element
   (a[n - 1] with the 0-based indices used by the code below).
1073-}
1074
-- | Compute the next (lexicographically) permutation of the given vector
-- in place.
--
-- Returns 'True' after rearranging the vector into its lexicographic
-- successor. Returns 'False' and leaves the vector unchanged in two cases:
-- the input is already the last permutation (its elements are in
-- non-increasing order), or the vector has fewer than two elements.
nextPermutation :: (PrimMonad m,Ord e,MVector v e) => v (PrimState m) e -> m Bool
-- Like the other generic operations in this module, inline so the
-- MVector/Ord dictionaries can be specialised away at the call site.
{-# INLINE nextPermutation #-}
nextPermutation v
    | dim < 2 = return False
    | otherwise = do
        val <- unsafeRead v 0
        -- One pass finds both k (largest index with v[k] < v[k+1]; -1 if
        -- none) and l (largest index > k with v[k] < v[l]).
        (k,l) <- loop val (-1) 0 val 1
        if k < 0
         then return False
         else unsafeSwap v k l >>
              reverse (unsafeSlice (k+1) (dim-k-1) v) >>
              return True
    -- Loop state: kval = v[k] (v[0] before any k is found), k and l as
    -- above, prev = v[i-1], and i is the index about to be read.
    where loop !kval !k !l !prev !i
              | i == dim = return (k,l)
              | otherwise  = do
                  cur <- unsafeRead v i
                  -- TODO: make tuple unboxed
                  let (kval',k') = if prev < cur then (prev,i-1) else (kval,k)
                      l' = if kval' < cur then i else l
                  loop kval' k' l' cur (i+1)
          dim = length v
1097