1-- | This module performs basic inlining of known functions
2module Language.PureScript.CoreImp.Optimizer.Inliner
3  ( inlineVariables
4  , inlineCommonValues
5  , inlineCommonOperators
6  , inlineFnComposition
7  , inlineUnsafeCoerce
8  , inlineUnsafePartial
9  , etaConvert
10  , unThunk
11  , evaluateIifes
12  ) where
13
14import Prelude.Compat
15
16import Control.Monad.Supply.Class (MonadSupply, freshName)
17
18import Data.Either (rights)
19import Data.Maybe (fromMaybe)
20import Data.String (IsString, fromString)
21import Data.Text (Text)
22import qualified Data.Text as T
23
24import Language.PureScript.PSString (PSString)
25import Language.PureScript.CoreImp.AST
26import Language.PureScript.CoreImp.Optimizer.Common
27import Language.PureScript.AST (SourceSpan(..))
28import qualified Language.PureScript.Constants.Prelude as C
29import qualified Language.PureScript.Constants.Prim as C
30
-- | Decide whether an expression is simple enough to be substituted for the
-- occurrences of a variable it is bound to.
--
-- Accepted forms: variables, numeric/string/boolean literals, and property
-- accesses ('Indexer') whose key and object are themselves inlinable.
--
-- TODO: potential bug. Inlining @x@ in @{ var x = 0; x.toFixed(10); }@
-- naively yields @0.toFixed(10)@, but valid JavaScript needs
-- @{ 0..toFixed(10); }@. This is probably better fixed in the pretty-printer.
shouldInline :: AST -> Bool
shouldInline ast = case ast of
  Var {}            -> True
  NumericLiteral {} -> True
  StringLiteral {}  -> True
  BooleanLiteral {} -> True
  Indexer _ key obj -> all shouldInline [key, obj]
  _                 -> False
42
-- | Simplify immediately-applied functions and trivial thunks.
--
-- * @{ return (function (a, b) { body })(x, y); }@ becomes @{ body }@ with
--   the parameters substituted by the arguments. This requires that:
--   every argument passes 'shouldInline'; the number of parameters equals
--   the number of arguments; and neither a parameter nor an argument is
--   rebound inside the function body.
-- * @function () { return f(); }@ becomes @f@.
etaConvert :: AST -> AST
etaConvert = everywhere convert
  where
  convert :: AST -> AST
  convert (Block ss [Return _ (App _ (Function _ Nothing idents block@(Block _ body)) args)])
    -- The arity check matters because 'zip' silently truncates. With fewer
    -- arguments than parameters, the surplus parameters would be left free
    -- in the inlined body, where the original code saw them as undefined.
    | length idents == length args &&
      all shouldInline args &&
      not (any ((`isRebound` block) . Var Nothing) idents) &&
      not (any (`isRebound` block) args)
      = Block ss (map (replaceIdents (zip idents args)) body)
  convert (Function _ Nothing [] (Block _ [Return _ (App _ fn [])])) = fn
  convert js = js
54
-- | When the last statement of a block returns the result of an immediately
-- invoked zero-argument function, splice that function's body into the
-- block:
--
-- @{ s1; ...; return (function () { body })(); }@ becomes @{ s1; ...; body }@
unThunk :: AST -> AST
unThunk = everywhere convert
  where
  convert :: AST -> AST
  convert (Block ss stmts)
    | Just (prefix, Return _ (App _ (Function _ Nothing [] (Block _ body)) [])) <- splitLast stmts
    = Block ss (prefix ++ body)
  convert js = js

  -- Split a list into everything but its final element, and that element.
  -- Returns 'Nothing' for an empty list, so no partial functions are needed.
  splitLast :: [a] -> Maybe ([a], a)
  splitLast = foldr step Nothing
    where
    step x Nothing = Just ([], x)
    step x (Just (front, final)) = Just (x : front, final)
65
-- | Evaluate immediately-invoked function expressions that are called with
-- no arguments and whose body is a single @return@:
--
-- * @(function () { return e; })()@ becomes @e@;
-- * @(function (a, b) { return e; })()@ becomes @e@ with every parameter
--   replaced by @undefined@ (the value an omitted argument has), provided no
--   parameter is reassigned within @e@.
evaluateIifes :: AST -> AST
evaluateIifes = everywhere convert
  where
  convert :: AST -> AST
  convert js = case js of
    App _ (Function _ Nothing [] (Block _ [Return _ ret])) [] ->
      ret
    App _ (Function _ Nothing params (Block _ [Return ss ret])) []
      | all (\param -> not (isReassigned param ret)) params ->
          replaceIdents [ (param, Var ss C.undefined) | param <- params ] ret
    _ -> js
74
-- | In statement lists (via 'removeFromBlock'), replace a variable whose
-- initialiser passes 'shouldInline' with that initialiser in every
-- following statement, and drop the declaration itself.
--
-- This only happens when no later statement reassigns or updates the
-- variable, or rebinds anything the initialiser refers to.
inlineVariables :: AST -> AST
inlineVariables = everywhere $ removeFromBlock go
  where
  go :: [AST] -> [AST]
  go (VariableIntroduction _ name (Just value) : rest)
    | shouldInline value && safeToSubstitute name value rest =
        go (map (replaceIdent name value) rest)
  go (stmt : rest) = stmt : go rest
  go [] = []

  -- True when none of the later statements invalidates the substitution.
  safeToSubstitute :: Text -> AST -> [AST] -> Bool
  safeToSubstitute name value rest = not (any invalidates rest)
    where
    invalidates stmt =
      isReassigned name stmt || isRebound value stmt || isUpdated name stmt
84
-- | Replace applications of well-known class members at the @Boolean@,
-- @Int@ and @Number@ instances with JavaScript literals and operators:
--
-- * @zero@ / @one@ (Semiring Int or Number) become @0@ / @1@;
-- * @bottom@ / @top@ (Bounded Boolean) become @false@ / @true@;
-- * @negate@ and @sub@ (Ring Int) and @add@ / @mul@ (Semiring Int) become
--   the corresponding JavaScript operators. Each result is bitwise-or'ed
--   with @0@, JavaScript's idiom for coercing to a 32-bit integer.
inlineCommonValues :: AST -> AST
inlineCommonValues = everywhere convert
  where
  convert :: AST -> AST
  convert (App ss fn [dict])
    | isDict' [semiringNumber, semiringInt] dict && isDict fnZero fn = NumericLiteral ss (Left 0)
    | isDict' [semiringNumber, semiringInt] dict && isDict fnOne fn = NumericLiteral ss (Left 1)
    | isDict boundedBoolean dict && isDict fnBottom fn = BooleanLiteral ss False
    | isDict boundedBoolean dict && isDict fnTop fn = BooleanLiteral ss True
  convert (App ss (App _ fn [dict]) [x])
    | isDict ringInt dict && isDict opNegate fn = toInt32 ss (Unary ss Negate x)
  convert (App ss (App _ (App _ fn [dict]) [x]) [y])
    | isDict semiringInt dict && isDict opAdd fn = intOp ss Add x y
    | isDict semiringInt dict && isDict opMul fn = intOp ss Multiply x y
    | isDict ringInt dict && isDict opSub fn = intOp ss Subtract x y
  convert other = other

  -- Members that have no top-level @op*@ counterpart in this module; the
  -- arithmetic members reuse the shared top-level 'opAdd', 'opMul',
  -- 'opSub' and 'opNegate' instead of duplicating them here.
  fnZero = (C.dataSemiring, C.zero)
  fnOne = (C.dataSemiring, C.one)
  fnBottom = (C.dataBounded, C.bottom)
  fnTop = (C.dataBounded, C.top)

  -- Integer binary operation, truncated back to an Int via @| 0@.
  intOp :: Maybe SourceSpan -> BinaryOperator -> AST -> AST -> AST
  intOp ss op x y = toInt32 ss (Binary ss op x y)

  -- Wrap an expression as @(e | 0)@.
  toInt32 :: Maybe SourceSpan -> AST -> AST
  toInt32 ss e = Binary ss BitwiseOr e (NumericLiteral ss (Left 0))
110
-- | Inline saturated applications of common operations into plain
-- JavaScript operators, property accesses and multi-argument functions.
-- The operations covered are:
--
-- * class members at known instances (Number, Int, String, Char, Boolean);
-- * the Int bit operations from 'C.dataIntBits';
-- * 'C.apply', 'C.applyFlipped' and 'C.unsafeIndex';
-- * the uncurried-function helpers @mkFnN@ / @runFnN@, and their @Eff@ and
--   @Effect@ variants, for N = 0..10.
--
-- The rewrites are combined with 'applyAll' and run top-down. 'applyAll'
-- presumably composes them in list order. The order below is therefore kept
-- exactly as-is, since one rewrite's output (e.g. the @App@ produced for
-- @apply@) could be matched by another rewrite.
inlineCommonOperators :: AST -> AST
inlineCommonOperators = everywhereTopDown $ applyAll $
  arithmeticOps ++ equalityOps ++ orderingOps ++ stringOps ++
  booleanOps ++ intBitsOps ++ nonClassOps ++ uncurriedOps
  where
  arithmeticOps :: [AST -> AST]
  arithmeticOps =
    [ binary semiringNumber opAdd Add
    , binary semiringNumber opMul Multiply
    , binary ringNumber opSub Subtract
    , unary ringNumber opNegate Negate
    , binary euclideanRingNumber opDiv Divide
    ]

  -- For each instance in turn: eq, then notEq.
  equalityOps :: [AST -> AST]
  equalityOps =
    [ binary dict member jsOp
    | dict <- [eqNumber, eqInt, eqString, eqChar, eqBoolean]
    , (member, jsOp) <- [(opEq, EqualTo), (opNotEq, NotEqualTo)]
    ]

  -- For each instance in turn: <, <=, >, >=.
  orderingOps :: [AST -> AST]
  orderingOps =
    [ binary dict member jsOp
    | dict <- [ordBoolean, ordChar, ordInt, ordNumber, ordString]
    , (member, jsOp) <-
        [ (opLessThan, LessThan)
        , (opLessThanOrEq, LessThanOrEqualTo)
        , (opGreaterThan, GreaterThan)
        , (opGreaterThanOrEq, GreaterThanOrEqualTo)
        ]
    ]

  stringOps :: [AST -> AST]
  stringOps = [ binary semigroupString opAppend Add ]

  booleanOps :: [AST -> AST]
  booleanOps =
    [ binary heytingAlgebraBoolean opConj And
    , binary heytingAlgebraBoolean opDisj Or
    , unary heytingAlgebraBoolean opNot Not
    ]

  intBitsOps :: [AST -> AST]
  intBitsOps =
    [ binary' C.dataIntBits member jsOp
    | (member, jsOp) <-
        [ (C.or, BitwiseOr)
        , (C.and, BitwiseAnd)
        , (C.xor, BitwiseXor)
        , (C.shl, ShiftLeft)
        , (C.shr, ShiftRight)
        , (C.zshr, ZeroFillShiftRight)
        ]
    ] ++
    [ unary' C.dataIntBits C.complement BitwiseNot ]

  nonClassOps :: [AST -> AST]
  nonClassOps =
    [ inlineNonClassFunction (isModFn (C.dataFunction, C.apply)) $ \f x -> App Nothing f [x]
    , inlineNonClassFunction (isModFn (C.dataFunction, C.applyFlipped)) $ \x f -> App Nothing f [x]
    , inlineNonClassFunction (isModFnWithDict (C.dataArray, C.unsafeIndex)) $ flip (Indexer Nothing)
    ]

  -- For each arity in turn: the mk rewrite, then the run rewrite.
  uncurriedOps :: [AST -> AST]
  uncurriedOps =
    concat [ [ mkFn i, runFn i ] | i <- arities ] ++
    concat [ [ mkEffFn C.controlMonadEffUncurried C.mkEffFn i
             , runEffFn C.controlMonadEffUncurried C.runEffFn i
             ]
           | i <- arities ] ++
    concat [ [ mkEffFn C.effectUncurried C.mkEffectFn i
             , runEffFn C.effectUncurried C.runEffectFn i
             ]
           | i <- arities ]

  arities :: [Int]
  arities = [0 .. 10]

  -- @member dict x y@ ~> @x <op> y@, when @dict@ / @member@ are the given
  -- instance and class member.
  binary :: (Text, PSString) -> (Text, PSString) -> BinaryOperator -> AST -> AST
  binary dict fns op ast = case ast of
    App ss (App _ (App _ fn [dict']) [x]) [y]
      | isDict dict dict' && isDict fns fn -> Binary ss op x y
    _ -> ast

  -- @fn x y@ ~> @x <op> y@ for a dictionary-free module function.
  binary' :: Text -> PSString -> BinaryOperator -> AST -> AST
  binary' moduleName opString op ast = case ast of
    App ss (App _ fn [x]) [y]
      | isDict (moduleName, opString) fn -> Binary ss op x y
    _ -> ast

  -- @member dict x@ ~> @<op> x@.
  unary :: (Text, PSString) -> (Text, PSString) -> UnaryOperator -> AST -> AST
  unary dicts fns op ast = case ast of
    App ss (App _ fn [dict']) [x]
      | isDict dicts dict' && isDict fns fn -> Unary ss op x
    _ -> ast

  -- @fn x@ ~> @<op> x@ for a dictionary-free module function.
  unary' :: Text -> PSString -> UnaryOperator -> AST -> AST
  unary' moduleName fnName op ast = case ast of
    App ss fn [x]
      | isDict (moduleName, fnName) fn -> Unary ss op x
    _ -> ast

  mkFn :: Int -> AST -> AST
  mkFn = mkFn' C.dataFunctionUncurried C.mkFn plainFunction
    where
    plainFunction fnSpan blockSpan retSpan params result =
      Function fnSpan Nothing params (Block blockSpan [Return retSpan result])

  -- Like 'mkFn', but the generated function also calls the returned thunk.
  mkEffFn :: Text -> Text -> Int -> AST -> AST
  mkEffFn modName fnName = mkFn' modName fnName effectFunction
    where
    effectFunction fnSpan blockSpan retSpan params result =
      Function fnSpan Nothing params (Block blockSpan [Return retSpan (App retSpan result [])])

  -- @mkFnN (\\a1 -> ... \\aN -> body)@ ~> an N-parameter function built by
  -- @res@. For N = 0, the single (ignored) parameter is dropped.
  mkFn' :: Text -> Text -> (Maybe SourceSpan -> Maybe SourceSpan -> Maybe SourceSpan -> [Text] -> AST -> AST) -> Int -> AST -> AST
  mkFn' modName fnName res 0 ast = case ast of
    App _ mkFnN [Function s1 Nothing [_] (Block s2 [Return s3 js])]
      | isNFn modName fnName 0 mkFnN -> res s1 s2 s3 [] js
    _ -> ast
  mkFn' modName fnName res n ast = case ast of
    App ss mkFnN [fn]
      | isNFn modName fnName n mkFnN
      , Just (params, [Return ss' ret]) <- collectArgs n [] fn
      -> res ss ss ss' params ret
    _ -> ast
    where
    -- Peel @remaining@ nested single-parameter functions, collecting their
    -- parameter names in order, and return them with the innermost body.
    collectArgs :: Int -> [Text] -> AST -> Maybe ([Text], [AST])
    collectArgs remaining params (Function _ Nothing [param] (Block _ stmts))
      | remaining == 1 && length params == n - 1 = Just (reverse (param : params), stmts)
    collectArgs remaining params (Function _ Nothing [param] (Block _ [Return _ inner])) =
      collectArgs (remaining - 1) (param : params) inner
    collectArgs _ _ _ = Nothing

  -- Is the expression the property @<prefix><n>@ of module variable
  -- @expectMod@?
  isNFn :: Text -> Text -> Int -> AST -> Bool
  isNFn expectMod prefix n (Indexer _ (StringLiteral _ name) (Var _ modName)) =
    modName == expectMod && name == fromString (T.unpack prefix <> show n)
  isNFn _ _ _ _ = False

  runFn :: Int -> AST -> AST
  runFn = runFn' C.dataFunctionUncurried C.runFn App

  -- Like 'runFn', but the call is wrapped in a zero-argument thunk.
  runEffFn :: Text -> Text -> Int -> AST -> AST
  runEffFn modName fnName = runFn' modName fnName thunked
    where
    thunked ss fn args = Function ss Nothing [] (Block ss [Return ss (App ss fn args)])

  -- @runFnN f a1 ... aN@ (nested single-argument applications) ~>
  -- @res ss f [a1, ..., aN]@; anything else is returned unchanged.
  runFn' :: Text -> Text -> (Maybe SourceSpan -> AST -> [AST] -> AST) -> Int -> AST -> AST
  runFn' modName runFnName res n ast = fromMaybe ast (peel n [] ast)
    where
    peel :: Int -> [AST] -> AST -> Maybe AST
    peel 0 args (App ss runFnN [fn])
      | isNFn modName runFnName n runFnN && length args == n = Just (res ss fn args)
    -- Once the counter hits 0 without a match, no deeper match is possible.
    peel remaining args (App _ f [arg])
      | remaining > 0 = peel (remaining - 1) (arg : args) f
    peel _ _ _ = Nothing

  -- @op x y@ ~> @f x y@ when @op@ satisfies the predicate.
  inlineNonClassFunction :: (AST -> Bool) -> (AST -> AST -> AST) -> AST -> AST
  inlineNonClassFunction p f ast = case ast of
    App _ (App _ op' [x]) [y] | p op' -> f x y
    _ -> ast

  -- Is the expression @module.member@ for the given (module, member)?
  isModFn :: (Text, PSString) -> AST -> Bool
  isModFn expected ast = case ast of
    Indexer _ (StringLiteral _ member) (Var _ modName) -> (modName, member) == expected
    _ -> False

  -- Like 'isModFn', but the member is applied to a single dictionary variable.
  isModFnWithDict :: (Text, PSString) -> AST -> Bool
  isModFnWithDict expected ast = case ast of
    App _ (Indexer _ (StringLiteral _ member) (Var _ modName)) [Var _ _] -> (modName, member) == expected
    _ -> False
260
-- | Inline function composition at the @Function@ instance of
-- @Semigroupoid@:
--
-- * a saturated composition becomes nested applications:
--   @(f <<< g) x@ becomes @f (g x)@, and @(f >>> g) x@ becomes @g (f x)@;
-- * an unsaturated chain becomes a single lambda, so @f <<< g@ becomes
--   @\\x -> f (g x)@. Every component that is itself an application is
--   first bound to a fresh variable, so it is evaluated only once.
inlineFnComposition :: forall m. MonadSupply m => AST -> m AST
inlineFnComposition = everywhereTopDownM convert where
  convert :: AST -> m AST
  convert ast = case ast of
    App s1 (App s2 (App _ (App _ fn [dict']) [x]) [y]) [z]
      | isFnCompose dict' fn -> pure (App s1 x [App s2 y [z]])
      | isFnComposeFlipped dict' fn -> pure (App s2 y [App s1 x [z]])
    App ss (App _ (App _ fn [dict']) _) _
      | isFnCompose dict' fn || isFnComposeFlipped dict' fn -> do
          -- Component names are drawn before the lambda's parameter name.
          components <- goApps ast
          param <- freshName
          pure (mkApps ss components param)
    _ -> pure ast

  -- Build
  -- @(function () { var v1 = e1; ...; return function (a) { return c1(c2(...(a))); }; })()@
  mkApps :: Maybe SourceSpan -> [Either AST (Text, AST)] -> Text -> AST
  mkApps ss fns a =
    App ss (Function ss Nothing [] (Block ss (bindings ++ [Return Nothing composed]))) []
    where
    bindings = [ VariableIntroduction ss name (Just value) | Right (name, value) <- fns ]
    composed = Function ss Nothing [a] (Block ss [Return Nothing body])
    body = foldr (\component acc -> App ss (mkApp component) [acc]) (Var ss a) fns

  -- An inline component is used directly; a bound one by its variable.
  mkApp :: Either AST (Text, AST) -> AST
  mkApp (Left inlined) = inlined
  mkApp (Right (name, value)) = Var (getSourceSpan value) name

  -- Flatten a composition chain into its components, outermost first.
  -- Applications get a fresh name; anything else is inlined as-is.
  goApps :: AST -> m [Either AST (Text, AST)]
  goApps ast = case ast of
    App _ (App _ (App _ fn [dict']) [x]) [y]
      | isFnCompose dict' fn -> (++) <$> goApps x <*> goApps y
      | isFnComposeFlipped dict' fn -> (++) <$> goApps y <*> goApps x
    App {} -> (\name -> [Right (name, ast)]) <$> freshName
    _ -> pure [Left ast]

  isFnCompose, isFnComposeFlipped :: AST -> AST -> Bool
  isFnCompose = isSemigroupoidFnMember (C.controlSemigroupoid, C.compose)
  isFnComposeFlipped = isSemigroupoidFnMember (C.controlSemigroupoid, C.composeFlipped)

  -- @fn@ is the given member and @dict'@ is the Semigroupoid Function instance.
  isSemigroupoidFnMember :: (Text, PSString) -> AST -> AST -> Bool
  isSemigroupoidFnMember member dict' fn = isDict semigroupoidFn dict' && isDict member fn
301
-- | Erase @unsafeCoerce@. A call @m.f(x)@, where @m@ is the 'C.unsafeCoerce'
-- module variable and @f@ is the 'C.unsafeCoerceFn' member, is replaced by
-- its argument @x@.
inlineUnsafeCoerce :: AST -> AST
inlineUnsafeCoerce = everywhereTopDown convert where
  convert :: AST -> AST
  convert ast = case ast of
    App _ (Indexer _ (StringLiteral _ member) (Var _ moduleName)) [arg]
      | member == C.unsafeCoerceFn, moduleName == C.unsafeCoerce -> arg
    _ -> ast
308
-- | Discharge @unsafePartial@: @m.unsafePartial(f)@, where @m@ is the
-- 'C.partialUnsafe' module variable, becomes @f(undefined)@. The
-- @undefined@ presumably stands in for the unused @Partial@ dictionary.
inlineUnsafePartial :: AST -> AST
inlineUnsafePartial = everywhereTopDown convert where
  convert :: AST -> AST
  convert ast = case ast of
    App ss (Indexer _ (StringLiteral _ member) (Var _ moduleName)) [comp]
      | member == C.unsafePartial, moduleName == C.partialUnsafe ->
          -- Apply to undefined here; the application should be optimized
          -- away later if it is safe to do so.
          App ss comp [Var ss C.undefined]
    _ -> ast
317
-- Instance dictionaries matched by the rewrites in this module, each given
-- as a (defining module, instance name) pair. Both components are
-- polymorphic via 'IsString', so the pairs can be used at whatever string
-- types the matchers expect.
semiringNumber, semiringInt, ringNumber, ringInt, euclideanRingNumber,
  eqNumber, eqInt, eqString, eqChar, eqBoolean,
  ordBoolean, ordNumber, ordInt, ordString, ordChar,
  semigroupString, boundedBoolean, heytingAlgebraBoolean, semigroupoidFn
  :: forall a b. (IsString a, IsString b) => (a, b)
semiringNumber        = (C.dataSemiring, C.semiringNumber)
semiringInt           = (C.dataSemiring, C.semiringInt)
ringNumber            = (C.dataRing, C.ringNumber)
ringInt               = (C.dataRing, C.ringInt)
euclideanRingNumber   = (C.dataEuclideanRing, C.euclideanRingNumber)
eqNumber              = (C.dataEq, C.eqNumber)
eqInt                 = (C.dataEq, C.eqInt)
eqString              = (C.dataEq, C.eqString)
eqChar                = (C.dataEq, C.eqChar)
eqBoolean             = (C.dataEq, C.eqBoolean)
ordBoolean            = (C.dataOrd, C.ordBoolean)
ordNumber             = (C.dataOrd, C.ordNumber)
ordInt                = (C.dataOrd, C.ordInt)
ordString             = (C.dataOrd, C.ordString)
ordChar               = (C.dataOrd, C.ordChar)
semigroupString       = (C.dataSemigroup, C.semigroupString)
boundedBoolean        = (C.dataBounded, C.boundedBoolean)
heytingAlgebraBoolean = (C.dataHeytingAlgebra, C.heytingAlgebraBoolean)
semigroupoidFn        = (C.controlSemigroupoid, C.semigroupoidFn)
374
-- Class members matched by the operator rewrites in this module, each given
-- as a (defining module, member name) pair, polymorphic via 'IsString'.
opAdd, opMul, opEq, opNotEq, opLessThan, opLessThanOrEq, opGreaterThan,
  opGreaterThanOrEq, opAppend, opSub, opNegate, opDiv, opConj, opDisj, opNot
  :: forall a b. (IsString a, IsString b) => (a, b)
opAdd             = (C.dataSemiring, C.add)
opMul             = (C.dataSemiring, C.mul)
opEq              = (C.dataEq, C.eq)
opNotEq           = (C.dataEq, C.notEq)
opLessThan        = (C.dataOrd, C.lessThan)
opLessThanOrEq    = (C.dataOrd, C.lessThanOrEq)
opGreaterThan     = (C.dataOrd, C.greaterThan)
opGreaterThanOrEq = (C.dataOrd, C.greaterThanOrEq)
opAppend          = (C.dataSemigroup, C.append)
opSub             = (C.dataRing, C.sub)
opNegate          = (C.dataRing, C.negate)
opDiv             = (C.dataEuclideanRing, C.div)
opConj            = (C.dataHeytingAlgebra, C.conj)
opDisj            = (C.dataHeytingAlgebra, C.disj)
opNot             = (C.dataHeytingAlgebra, C.not)
419