-- | This module performs basic inlining of known functions: eta-conversion,
-- thunk and IIFE elimination, variable inlining, and replacement of common
-- type class members at primitive instances with native JavaScript values
-- and operators.
module Language.PureScript.CoreImp.Optimizer.Inliner
  ( inlineVariables
  , inlineCommonValues
  , inlineCommonOperators
  , inlineFnComposition
  , inlineUnsafeCoerce
  , inlineUnsafePartial
  , etaConvert
  , unThunk
  , evaluateIifes
  ) where

import Prelude.Compat

import Control.Monad.Supply.Class (MonadSupply, freshName)

import Data.Either (rights)
import Data.Maybe (fromMaybe)
import Data.String (IsString, fromString)
import Data.Text (Text)
import qualified Data.Text as T

import Language.PureScript.PSString (PSString)
import Language.PureScript.CoreImp.AST
import Language.PureScript.CoreImp.Optimizer.Common
import Language.PureScript.AST (SourceSpan(..))
import qualified Language.PureScript.Constants.Prelude as C
import qualified Language.PureScript.Constants.Prim as C

-- | Is this expression simple enough to be copied into every use site of a
-- variable bound to it? Accepts variables, numeric, string and boolean
-- literals, and indexers whose index and target are themselves inlinable.
--
-- TODO: Potential bug: inlining a numeric literal into a property access is
-- not always valid JavaScript. @{ var x = 0; x.toFixed(10); }@ must become
-- @{ 0..toFixed(10); }@. This probably needs to be fixed in the
-- pretty-printer instead.
shouldInline :: AST -> Bool
shouldInline (Var _ _) = True
shouldInline (NumericLiteral _ _) = True
shouldInline (StringLiteral _ _) = True
shouldInline (BooleanLiteral _ _) = True
shouldInline (Indexer _ index val) = shouldInline index && shouldInline val
shouldInline _ = False

-- | Eta-convert immediately-applied anonymous functions where it is safe:
--
-- * @{ return (function (a, b) { body })(x, y); }@ becomes a block holding
--   @body@ with @a := x@ and @b := y@ substituted. This applies only when
--   every argument is inlinable (see 'shouldInline'), every parameter has a
--   corresponding argument, and neither the parameters nor the arguments are
--   rebound inside @body@. Surplus arguments are dropped.
-- * @function () { return f(); }@ becomes @f@.
etaConvert :: AST -> AST
etaConvert = everywhere convert
  where
  convert :: AST -> AST
  convert (Block ss [Return _ (App _ (Function _ Nothing idents block@(Block _ body)) args)])
    -- Every parameter needs an argument to substitute. Otherwise 'zip' would
    -- silently leave the unmatched parameters as free variables in the
    -- inlined body, where JavaScript would have bound them to undefined.
    | length idents <= length args &&
      all shouldInline args &&
      not (any ((`isRebound` block) . Var Nothing) idents) &&
      not (any (`isRebound` block) args)
      = Block ss (map (replaceIdents (zip idents args)) body)
  convert (Function _ Nothing [] (Block _ [Return _ (App _ fn [])])) = fn
  convert js = js

-- | Flatten a block whose final statement returns an immediately-invoked,
-- zero-argument anonymous function. The thunk's statements are spliced into
-- the enclosing block in place of that @return@.
unThunk :: AST -> AST
unThunk = everywhere convert
  where
  convert :: AST -> AST
  -- Matching on the reversed statement list inspects the final statement
  -- without the partial 'last' / 'init'. Empty blocks and blocks whose last
  -- statement does not match fall through to the identity case below.
  convert (Block ss jss)
    | Return _ (App _ (Function _ Nothing [] (Block _ body)) []) : before <- reverse jss
    = Block ss (reverse before ++ body)
  convert js = js

-- | Evaluate immediately-invoked anonymous functions that are called with no
-- arguments and whose body is a single @return@:
--
-- * @(function () { return e; })()@ becomes @e@.
-- * @(function (a, b) { return e; })()@ becomes @e@ with every parameter
--   replaced by @undefined@, provided no parameter is reassigned within @e@.
evaluateIifes :: AST -> AST
evaluateIifes = everywhere convert
  where
  convert :: AST -> AST
  convert (App _ (Function _ Nothing [] (Block _ [Return _ ret])) []) = ret
  convert (App _ (Function _ Nothing idents (Block _ [Return ss ret])) [])
    | not (any (`isReassigned` ret) idents) = replaceIdents (map (, Var ss C.undefined) idents) ret
  convert js = js

-- | Inline variables whose initializer is inlinable (see 'shouldInline')
-- into the statements that follow their declaration, and drop the
-- declaration. A variable is inlined only if it is never reassigned or
-- updated by a later statement, and its initializer is not rebound by one
-- (per 'isRebound').
inlineVariables :: AST -> AST
inlineVariables = everywhere $ removeFromBlock go
  where
  go :: [AST] -> [AST]
  go [] = []
  go (VariableIntroduction _ var (Just js) : sts)
    | shouldInline js &&
      not (any (isReassigned var) sts) &&
      not (any (isRebound js) sts) &&
      not (any (isUpdated var) sts) =
      go (map (replaceIdent var js) sts)
  go (s:sts) = s : go sts

-- | Replace common type class members at known primitive instances with the
-- equivalent JavaScript literal or operator:
--
-- * @zero@ / @one@ at @Semiring Number@ / @Semiring Int@ become @0@ / @1@.
-- * @bottom@ / @top@ at @Bounded Boolean@ become @false@ / @true@.
-- * @negate@, @add@, @mul@ and @sub@ at the @Int@ instances become native
--   operators followed by @| 0@, which truncates the result to an integer.
inlineCommonValues :: AST -> AST
inlineCommonValues = everywhere convert
  where
  convert :: AST -> AST
  convert (App ss fn [dict])
    | isDict' [semiringNumber, semiringInt] dict && isDict fnZero fn = NumericLiteral ss (Left 0)
    | isDict' [semiringNumber, semiringInt] dict && isDict fnOne fn = NumericLiteral ss (Left 1)
    | isDict boundedBoolean dict && isDict fnBottom fn = BooleanLiteral ss False
    | isDict boundedBoolean dict && isDict fnTop fn = BooleanLiteral ss True
  convert (App ss (App _ fn [dict]) [x])
    | isDict ringInt dict && isDict fnNegate fn = Binary ss BitwiseOr (Unary ss Negate x) (NumericLiteral ss (Left 0))
  convert (App ss (App _ (App _ fn [dict]) [x]) [y])
    | isDict semiringInt dict && isDict fnAdd fn = intOp ss Add x y
    | isDict semiringInt dict && isDict fnMultiply fn = intOp ss Multiply x y
    | isDict ringInt dict && isDict fnSubtract fn = intOp ss Subtract x y
  convert other = other
  fnZero = (C.dataSemiring, C.zero)
  fnOne = (C.dataSemiring, C.one)
  fnBottom = (C.dataBounded, C.bottom)
  fnTop = (C.dataBounded, C.top)
  fnAdd = (C.dataSemiring, C.add)
  fnMultiply = (C.dataSemiring, C.mul)
  fnSubtract = (C.dataRing, C.sub)
  fnNegate = (C.dataRing, C.negate)
  -- Binary Int operation followed by @| 0@.
  intOp ss op x y = Binary ss BitwiseOr (Binary ss op x y) (NumericLiteral ss (Left 0))

-- | Replace saturated applications of common type class operators at known
-- primitive instances (e.g. @add@ at @Semiring Number@) with the native
-- JavaScript operator, and inline a few non-class functions:
--
-- * @apply@ / @applyFlipped@ become plain function application.
-- * @unsafeIndex@ (applied to a dictionary) becomes an indexer.
-- * The @Int@ bitwise functions become native bitwise operators.
-- * The uncurried-function helpers (@mkFnN@ / @runFnN@ and their effectful
--   variants, for N = 0..10) become native N-ary functions and calls.
inlineCommonOperators :: AST -> AST
inlineCommonOperators = everywhereTopDown $ applyAll $
  [ binary semiringNumber opAdd Add
  , binary semiringNumber opMul Multiply

  , binary ringNumber opSub Subtract
  , unary ringNumber opNegate Negate

  , binary euclideanRingNumber opDiv Divide

  , binary eqNumber opEq EqualTo
  , binary eqNumber opNotEq NotEqualTo
  , binary eqInt opEq EqualTo
  , binary eqInt opNotEq NotEqualTo
  , binary eqString opEq EqualTo
  , binary eqString opNotEq NotEqualTo
  , binary eqChar opEq EqualTo
  , binary eqChar opNotEq NotEqualTo
  , binary eqBoolean opEq EqualTo
  , binary eqBoolean opNotEq NotEqualTo

  , binary ordBoolean opLessThan LessThan
  , binary ordBoolean opLessThanOrEq LessThanOrEqualTo
  , binary ordBoolean opGreaterThan GreaterThan
  , binary ordBoolean opGreaterThanOrEq GreaterThanOrEqualTo
  , binary ordChar opLessThan LessThan
  , binary ordChar opLessThanOrEq LessThanOrEqualTo
  , binary ordChar opGreaterThan GreaterThan
  , binary ordChar opGreaterThanOrEq GreaterThanOrEqualTo
  , binary ordInt opLessThan LessThan
  , binary ordInt opLessThanOrEq LessThanOrEqualTo
  , binary ordInt opGreaterThan GreaterThan
  , binary ordInt opGreaterThanOrEq GreaterThanOrEqualTo
  , binary ordNumber opLessThan LessThan
  , binary ordNumber opLessThanOrEq LessThanOrEqualTo
  , binary ordNumber opGreaterThan GreaterThan
  , binary ordNumber opGreaterThanOrEq GreaterThanOrEqualTo
  , binary ordString opLessThan LessThan
  , binary ordString opLessThanOrEq LessThanOrEqualTo
  , binary ordString opGreaterThan GreaterThan
  , binary ordString opGreaterThanOrEq GreaterThanOrEqualTo

  , binary semigroupString opAppend Add

  , binary heytingAlgebraBoolean opConj And
  , binary heytingAlgebraBoolean opDisj Or
  , unary heytingAlgebraBoolean opNot Not

  , binary' C.dataIntBits C.or BitwiseOr
  , binary' C.dataIntBits C.and BitwiseAnd
  , binary' C.dataIntBits C.xor BitwiseXor
  , binary' C.dataIntBits C.shl ShiftLeft
  , binary' C.dataIntBits C.shr ShiftRight
  , binary' C.dataIntBits C.zshr ZeroFillShiftRight
  , unary' C.dataIntBits C.complement BitwiseNot

  , inlineNonClassFunction (isModFn (C.dataFunction, C.apply)) $ \f x -> App Nothing f [x]
  , inlineNonClassFunction (isModFn (C.dataFunction, C.applyFlipped)) $ \x f -> App Nothing f [x]
  , inlineNonClassFunction (isModFnWithDict (C.dataArray, C.unsafeIndex)) $ flip (Indexer Nothing)
  ] ++
  [ fn | i <- [0..10], fn <- [ mkFn i, runFn i ] ] ++
  [ fn | i <- [0..10], fn <- [ mkEffFn C.controlMonadEffUncurried C.mkEffFn i, runEffFn C.controlMonadEffUncurried C.runEffFn i ] ] ++
  [ fn | i <- [0..10], fn <- [ mkEffFn C.effectUncurried C.mkEffectFn i, runEffFn C.effectUncurried C.runEffectFn i ] ]
  where
  -- | @member dict x y@ with the given dictionary and class member becomes
  -- the native binary operator @x op y@.
  binary :: (Text, PSString) -> (Text, PSString) -> BinaryOperator -> AST -> AST
  binary dict fns op = convert where
    convert :: AST -> AST
    convert (App ss (App _ (App _ fn [dict']) [x]) [y]) | isDict dict dict' && isDict fns fn = Binary ss op x y
    convert other = other
  -- | Like 'binary', for a dictionary-free function @moduleName[opString]@.
  binary' :: Text -> PSString -> BinaryOperator -> AST -> AST
  binary' moduleName opString op = convert where
    convert :: AST -> AST
    convert (App ss (App _ fn [x]) [y]) | isDict (moduleName, opString) fn = Binary ss op x y
    convert other = other
  -- | @member dict x@ with the given dictionary and class member becomes the
  -- native unary operator @op x@.
  unary :: (Text, PSString) -> (Text, PSString) -> UnaryOperator -> AST -> AST
  unary dicts fns op = convert where
    convert :: AST -> AST
    convert (App ss (App _ fn [dict']) [x]) | isDict dicts dict' && isDict fns fn = Unary ss op x
    convert other = other
  -- | Like 'unary', for a dictionary-free function @moduleName[fnName]@.
  unary' :: Text -> PSString -> UnaryOperator -> AST -> AST
  unary' moduleName fnName op = convert where
    convert :: AST -> AST
    convert (App ss fn [x]) | isDict (moduleName, fnName) fn = Unary ss op x
    convert other = other

  -- | @mkFnN@ applied to N nested one-argument functions becomes a single
  -- N-ary function with the same body.
  mkFn :: Int -> AST -> AST
  mkFn = mkFn' C.dataFunctionUncurried C.mkFn $ \ss1 ss2 ss3 args js ->
    Function ss1 Nothing args (Block ss2 [Return ss3 js])

  -- | Like 'mkFn', but the resulting N-ary function calls the collected body
  -- with no arguments before returning it.
  mkEffFn :: Text -> Text -> Int -> AST -> AST
  mkEffFn modName fnName = mkFn' modName fnName $ \ss1 ss2 ss3 args js ->
    Function ss1 Nothing args (Block ss2 [Return ss3 (App ss3 js [])])

  -- | Shared implementation of 'mkFn' / 'mkEffFn'. It recognises
  -- @modName[fnNameN](function (a1) { return function (a2) { ... }; })@ and
  -- hands the collected parameter names and innermost returned expression
  -- to @res@ to build the replacement.
  mkFn' :: Text -> Text -> (Maybe SourceSpan -> Maybe SourceSpan -> Maybe SourceSpan -> [Text] -> AST -> AST) -> Int -> AST -> AST
  mkFn' modName fnName res 0 = convert where
    convert :: AST -> AST
    -- The 0-ary form wraps a function of one (ignored) argument.
    convert (App _ mkFnN [Function s1 Nothing [_] (Block s2 [Return s3 js])]) | isNFn modName fnName 0 mkFnN =
      res s1 s2 s3 [] js
    convert other = other
  mkFn' modName fnName res n = convert where
    convert :: AST -> AST
    convert orig@(App ss mkFnN [fn]) | isNFn modName fnName n mkFnN =
      case collectArgs n [] fn of
        Just (args, [Return ss' ret]) -> res ss ss ss' args ret
        _ -> orig
    convert other = other
    -- Peel off exactly @n@ nested one-argument functions, returning their
    -- parameter names (outermost first) and the innermost body statements.
    collectArgs :: Int -> [Text] -> AST -> Maybe ([Text], [AST])
    collectArgs 1 acc (Function _ Nothing [oneArg] (Block _ js)) | length acc == n - 1 = Just (reverse (oneArg : acc), js)
    collectArgs m acc (Function _ Nothing [oneArg] (Block _ [Return _ ret])) = collectArgs (m - 1) (oneArg : acc) ret
    collectArgs _ _ _ = Nothing

  -- | Is the expression the reference @expectMod[prefix ++ show n]@?
  isNFn :: Text -> Text -> Int -> AST -> Bool
  isNFn expectMod prefix n (Indexer _ (StringLiteral _ name) (Var _ modName)) | modName == expectMod =
    name == fromString (T.unpack prefix <> show n)
  isNFn _ _ _ _ = False

  -- | @runFnN f a1 ... aN@ becomes the native call @f(a1, ..., aN)@.
  runFn :: Int -> AST -> AST
  runFn = runFn' C.dataFunctionUncurried C.runFn App

  -- | Like 'runFn', but the native call is wrapped in a zero-argument
  -- function: @function () { return f(a1, ..., aN); }@.
  runEffFn :: Text -> Text -> Int -> AST -> AST
  runEffFn modName fnName = runFn' modName fnName $ \ss fn acc ->
    Function ss Nothing [] (Block ss [Return ss (App ss fn acc)])

  -- | Shared implementation of 'runFn' / 'runEffFn'. It recognises
  -- @modName[runFnNameN](f)(a1)...(aN)@ with exactly @n@ single-argument
  -- applications and hands @f@ and the collected arguments (in order) to
  -- @res@. Non-matching expressions are returned unchanged.
  runFn' :: Text -> Text -> (Maybe SourceSpan -> AST -> [AST] -> AST) -> Int -> AST -> AST
  runFn' modName runFnName res n = convert where
    convert :: AST -> AST
    convert js = fromMaybe js $ go n [] js

    go :: Int -> [AST] -> AST -> Maybe AST
    go 0 acc (App ss runFnN [fn]) | isNFn modName runFnName n runFnN && length acc == n =
      Just $ res ss fn acc
    go m acc (App _ lhs [arg]) = go (m - 1) (arg : acc) lhs
    go _ _ _ = Nothing

  -- | Rewrite the saturated two-argument application @op'(x)(y)@ into
  -- @f x y@ whenever @p op'@ holds.
  inlineNonClassFunction :: (AST -> Bool) -> (AST -> AST -> AST) -> AST -> AST
  inlineNonClassFunction p f = convert where
    convert :: AST -> AST
    convert (App _ (App _ op' [x]) [y]) | p op' = f x y
    convert other = other

  -- | Is the expression the reference @m[op]@?
  isModFn :: (Text, PSString) -> AST -> Bool
  isModFn (m, op) (Indexer _ (StringLiteral _ op') (Var _ m')) =
    m == m' && op == op'
  isModFn _ _ = False

  -- | Is the expression @m[op]@ applied to a single variable (presumably a
  -- type class dictionary)?
  isModFnWithDict :: (Text, PSString) -> AST -> Bool
  isModFnWithDict (m, op) (App _ (Indexer _ (StringLiteral _ op') (Var _ m')) [Var _ _]) =
    m == m' && op == op'
  isModFnWithDict _ _ = False

-- | Inline function composition at the @Function@ semigroupoid instance:
--
-- * @(f <<< g) x@ becomes @f (g x)@, and @(f >>> g) x@ becomes @g (f x)@.
-- * An unapplied composition chain @f <<< g@ becomes
--   @\\x -> f (g x)@. Any composed operand that is itself an application is
--   first bound to a fresh variable, so it is evaluated once rather than on
--   every call. The whole result is wrapped in an IIFE that holds those
--   bindings.
inlineFnComposition :: forall m. MonadSupply m => AST -> m AST
inlineFnComposition = everywhereTopDownM convert where
  convert :: AST -> m AST
  convert (App s1 (App s2 (App _ (App _ fn [dict']) [x]) [y]) [z])
    | isFnCompose dict' fn = return $ App s1 x [App s2 y [z]]
    | isFnComposeFlipped dict' fn = return $ App s2 y [App s1 x [z]]
  convert app@(App ss (App _ (App _ fn [dict']) _) _)
    | isFnCompose dict' fn || isFnComposeFlipped dict' fn = mkApps ss <$> goApps app <*> freshName
  convert other = return other

  -- Build the IIFE: one variable introduction per hoisted operand (the
  -- 'Right' entries), then a one-argument function applying every operand in
  -- order, outermost first, to its parameter @a@.
  mkApps :: Maybe SourceSpan -> [Either AST (Text, AST)] -> Text -> AST
  mkApps ss fns a = App ss (Function ss Nothing [] (Block ss $ vars <> [Return Nothing comp])) []
    where
    vars = uncurry (VariableIntroduction ss) . fmap Just <$> rights fns
    comp = Function ss Nothing [a] (Block ss [Return Nothing apps])
    apps = foldr (\fn acc -> App ss (mkApp fn) [acc]) (Var ss a) fns

  -- Operand reference: the expression itself, or the hoisted variable's name.
  mkApp :: Either AST (Text, AST) -> AST
  mkApp = either id $ \(name, arg) -> Var (getSourceSpan arg) name

  -- Flatten a composition chain into its operands, outermost function first.
  -- Applications get a fresh name ('Right'); everything else is used as-is
  -- ('Left').
  goApps :: AST -> m [Either AST (Text, AST)]
  goApps (App _ (App _ (App _ fn [dict']) [x]) [y])
    | isFnCompose dict' fn = mappend <$> goApps x <*> goApps y
    | isFnComposeFlipped dict' fn = mappend <$> goApps y <*> goApps x
  goApps app@App {} = pure . Right . (,app) <$> freshName
  goApps other = pure [Left other]

  isFnCompose :: AST -> AST -> Bool
  isFnCompose dict' fn = isDict semigroupoidFn dict' && isDict fnCompose fn

  isFnComposeFlipped :: AST -> AST -> Bool
  isFnComposeFlipped dict' fn = isDict semigroupoidFn dict' && isDict fnComposeFlipped fn

  fnCompose :: forall a b. (IsString a, IsString b) => (a, b)
  fnCompose = (C.controlSemigroupoid, C.compose)

  fnComposeFlipped :: forall a b. (IsString a, IsString b) => (a, b)
  fnComposeFlipped = (C.controlSemigroupoid, C.composeFlipped)

-- | Replace @unsafeCoerce(x)@, referenced through its defining module, with
-- @x@.
inlineUnsafeCoerce :: AST -> AST
inlineUnsafeCoerce = everywhereTopDown convert where
  convert (App _ (Indexer _ (StringLiteral _ unsafeCoerceFn) (Var _ unsafeCoerce)) [ comp ])
    | unsafeCoerceFn == C.unsafeCoerceFn && unsafeCoerce == C.unsafeCoerce
    = comp
  convert other = other

-- | Replace @unsafePartial(f)@, referenced through its defining module, with
-- @f(undefined)@.
inlineUnsafePartial :: AST -> AST
inlineUnsafePartial = everywhereTopDown convert where
  convert (App ss (Indexer _ (StringLiteral _ unsafePartial) (Var _ partialUnsafe)) [ comp ])
    | unsafePartial == C.unsafePartial && partialUnsafe == C.partialUnsafe
    -- Apply to undefined here, the application should be optimized away
    -- if it is safe to do so
    = App ss comp [ Var ss C.undefined ]
  convert other = other

-- Known instance dictionaries, as (module name, dictionary name) pairs.

semiringNumber :: forall a b. (IsString a, IsString b) => (a, b)
semiringNumber = (C.dataSemiring, C.semiringNumber)

semiringInt :: forall a b. (IsString a, IsString b) => (a, b)
semiringInt = (C.dataSemiring, C.semiringInt)

ringNumber :: forall a b. (IsString a, IsString b) => (a, b)
ringNumber = (C.dataRing, C.ringNumber)

ringInt :: forall a b. (IsString a, IsString b) => (a, b)
ringInt = (C.dataRing, C.ringInt)

euclideanRingNumber :: forall a b. (IsString a, IsString b) => (a, b)
euclideanRingNumber = (C.dataEuclideanRing, C.euclideanRingNumber)

eqNumber :: forall a b. (IsString a, IsString b) => (a, b)
eqNumber = (C.dataEq, C.eqNumber)

eqInt :: forall a b. (IsString a, IsString b) => (a, b)
eqInt = (C.dataEq, C.eqInt)

eqString :: forall a b. (IsString a, IsString b) => (a, b)
eqString = (C.dataEq, C.eqString)

eqChar :: forall a b. (IsString a, IsString b) => (a, b)
eqChar = (C.dataEq, C.eqChar)

eqBoolean :: forall a b. (IsString a, IsString b) => (a, b)
eqBoolean = (C.dataEq, C.eqBoolean)

ordBoolean :: forall a b. (IsString a, IsString b) => (a, b)
ordBoolean = (C.dataOrd, C.ordBoolean)

ordNumber :: forall a b. (IsString a, IsString b) => (a, b)
ordNumber = (C.dataOrd, C.ordNumber)

ordInt :: forall a b. (IsString a, IsString b) => (a, b)
ordInt = (C.dataOrd, C.ordInt)

ordString :: forall a b. (IsString a, IsString b) => (a, b)
ordString = (C.dataOrd, C.ordString)

ordChar :: forall a b. (IsString a, IsString b) => (a, b)
ordChar = (C.dataOrd, C.ordChar)

semigroupString :: forall a b. (IsString a, IsString b) => (a, b)
semigroupString = (C.dataSemigroup, C.semigroupString)

boundedBoolean :: forall a b. (IsString a, IsString b) => (a, b)
boundedBoolean = (C.dataBounded, C.boundedBoolean)

heytingAlgebraBoolean :: forall a b. (IsString a, IsString b) => (a, b)
heytingAlgebraBoolean = (C.dataHeytingAlgebra, C.heytingAlgebraBoolean)

semigroupoidFn :: forall a b. (IsString a, IsString b) => (a, b)
semigroupoidFn = (C.controlSemigroupoid, C.semigroupoidFn)

-- Known class members, as (module name, member name) pairs.

opAdd :: forall a b. (IsString a, IsString b) => (a, b)
opAdd = (C.dataSemiring, C.add)

opMul :: forall a b. (IsString a, IsString b) => (a, b)
opMul = (C.dataSemiring, C.mul)

opEq :: forall a b. (IsString a, IsString b) => (a, b)
opEq = (C.dataEq, C.eq)

opNotEq :: forall a b. (IsString a, IsString b) => (a, b)
opNotEq = (C.dataEq, C.notEq)

opLessThan :: forall a b. (IsString a, IsString b) => (a, b)
opLessThan = (C.dataOrd, C.lessThan)

opLessThanOrEq :: forall a b. (IsString a, IsString b) => (a, b)
opLessThanOrEq = (C.dataOrd, C.lessThanOrEq)

opGreaterThan :: forall a b. (IsString a, IsString b) => (a, b)
opGreaterThan = (C.dataOrd, C.greaterThan)

opGreaterThanOrEq :: forall a b. (IsString a, IsString b) => (a, b)
opGreaterThanOrEq = (C.dataOrd, C.greaterThanOrEq)

opAppend :: forall a b. (IsString a, IsString b) => (a, b)
opAppend = (C.dataSemigroup, C.append)

opSub :: forall a b. (IsString a, IsString b) => (a, b)
opSub = (C.dataRing, C.sub)

opNegate :: forall a b. (IsString a, IsString b) => (a, b)
opNegate = (C.dataRing, C.negate)

opDiv :: forall a b. (IsString a, IsString b) => (a, b)
opDiv = (C.dataEuclideanRing, C.div)

opConj :: forall a b. (IsString a, IsString b) => (a, b)
opConj = (C.dataHeytingAlgebra, C.conj)

opDisj :: forall a b. (IsString a, IsString b) => (a, b)
opDisj = (C.dataHeytingAlgebra, C.disj)

opNot :: forall a b. (IsString a, IsString b) => (a, b)
opNot = (C.dataHeytingAlgebra, C.not)