summaryrefslogtreecommitdiffhomepage
path: root/lib/Language/Haskell/Stylish/GHC.hs
blob: ee2d59fc83c5cac348ba06b54ee129c60dc05dff (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
{-# LANGUAGE LambdaCase #-}
{-# OPTIONS_GHC -Wno-missing-fields #-}
-- | Utility functions for working with the GHC AST
module Language.Haskell.Stylish.GHC
  ( dropAfterLocated
  , dropBeforeLocated
  , dropBeforeAndAfter
    -- * Unsafe getters
  , getEndLineUnsafe
  , getStartLineUnsafe
    -- * Standard settings
  , baseDynFlags
    -- * Positions
  , unLocated
    -- * Outputable operators
  , showOutputable
  , compareOutputable
  ) where

--------------------------------------------------------------------------------
import           Data.Function (on)

--------------------------------------------------------------------------------
import           DynFlags                        (Settings(..), defaultDynFlags)
import qualified DynFlags                        as GHC
import           FileSettings                    (FileSettings(..))
import           GHC.Fingerprint                 (fingerprint0)
import           GHC.Platform
import           GHC.Version                     (cProjectVersion)
import           GhcNameVersion                  (GhcNameVersion(..))
import           PlatformConstants               (PlatformConstants(..))
import           SrcLoc                          (GenLocated(..), SrcSpan(..))
import           SrcLoc                          (Located, RealLocated)
import           SrcLoc                          (srcSpanStartLine, srcSpanEndLine)
import           ToolSettings                    (ToolSettings(..))
import qualified Outputable                      as GHC

-- | Line number on which a located node starts.
--
-- Partial: calls 'error' when the location is anything other than a
-- 'RealSrcSpan'.
getStartLineUnsafe :: Located a -> Int
getStartLineUnsafe (L (RealSrcSpan realSpan) _) = srcSpanStartLine realSpan
getStartLineUnsafe _ = error "could not get start line of block"

-- | Line number on which a located node ends.
--
-- Partial: calls 'error' when the location is anything other than a
-- 'RealSrcSpan'.
getEndLineUnsafe :: Located a -> Int
getEndLineUnsafe (L (RealSrcSpan realSpan) _) = srcSpanEndLine realSpan
getEndLineUnsafe _ = error "could not get end line of block"

-- | Keep only the elements that start on or before the last line of the
-- given location.
--
-- When the location is 'Nothing', or is not a 'RealSrcSpan', the list is
-- returned unchanged.
dropAfterLocated :: Maybe (Located a) -> [RealLocated b] -> [RealLocated b]
dropAfterLocated (Just (L (RealSrcSpan anchor) _)) items =
  [ item
  | item@(L itemSpan _) <- items
  , srcSpanStartLine itemSpan <= srcSpanEndLine anchor
  ]
dropAfterLocated _ items = items

-- | Keep only the elements that end on or after the first line of the
-- given location.
--
-- When the location is 'Nothing', or is not a 'RealSrcSpan', the list is
-- returned unchanged.
dropBeforeLocated :: Maybe (Located a) -> [RealLocated b] -> [RealLocated b]
dropBeforeLocated (Just (L (RealSrcSpan anchor) _)) items =
  [ item
  | item@(L itemSpan _) <- items
  , srcSpanEndLine itemSpan >= srcSpanStartLine anchor
  ]
dropBeforeLocated _ items = items

-- | Apply 'dropAfterLocated' and then 'dropBeforeLocated' with the same
-- location, so only elements passing both filters remain.
dropBeforeAndAfter :: Located a -> [RealLocated b] -> [RealLocated b]
dropBeforeAndAfter loc items =
  dropBeforeLocated anchor (dropAfterLocated anchor items)
  where
    anchor = Just loc

-- | A 'GHC.DynFlags' value built with 'defaultDynFlags' from placeholder
-- settings.
--
-- Only the fields listed below are populated; every other field of the
-- nested records is left uninitialised (this module switches off the
-- missing-fields warning for that reason).
baseDynFlags :: GHC.DynFlags
baseDynFlags = defaultDynFlags stubSettings stubLlvmConfig
  where
    stubLlvmConfig = GHC.LlvmConfig [] []

    stubSettings = GHC.Settings
      { sGhcNameVersion    = GhcNameVersion "stylish-haskell" cProjectVersion
      , sFileSettings      = FileSettings {}
      , sToolSettings      = stubToolSettings
      , sPlatformConstants = stubPlatformConstants
      , sTargetPlatform    = stubPlatform
      , sPlatformMisc      = PlatformMisc {}
      , sRawSettings       = []
      }

    stubToolSettings = ToolSettings
      { toolSettings_opt_P_fingerprint = fingerprint0
      , toolSettings_pgm_F             = ""
      }

    -- Word size here (8) matches 'PW8' in the platform record below.
    stubPlatformConstants = PlatformConstants
      { pc_DYNAMIC_BY_DEFAULT = False
      , pc_WORD_SIZE          = 8
      }

    stubPlatform = Platform
      { platformMini = PlatformMini
          { platformMini_arch = ArchUnknown
          , platformMini_os   = OSUnknown
          }
      , platformWordSize                 = PW8
      , platformUnregisterised           = True
      , platformHasIdentDirective        = False
      , platformHasSubsectionsViaSymbols = False
      , platformIsCrossCompiling         = False
      }

-- | Discard the source location and return the wrapped value.
unLocated :: Located a -> a
unLocated located = case located of
  L _ payload -> payload

-- | Render an 'GHC.Outputable' value to a 'String' via 'GHC.showPpr',
-- using 'baseDynFlags' as the flags.
showOutputable :: GHC.Outputable a => a -> String
showOutputable value = GHC.showPpr baseDynFlags value

-- | Order two 'GHC.Outputable' values by comparing the strings produced
-- for them by 'showOutputable'.
compareOutputable :: GHC.Outputable a => a -> a -> Ordering
compareOutputable lhs rhs = on compare showOutputable lhs rhs