diff options
Diffstat (limited to 'lib')
-rw-r--r-- | lib/Language/Haskell/Stylish.hs | 17 | ||||
-rw-r--r-- | lib/Language/Haskell/Stylish/Align.hs | 53 | ||||
-rw-r--r-- | lib/Language/Haskell/Stylish/Block.hs | 30 | ||||
-rw-r--r-- | lib/Language/Haskell/Stylish/Config.hs | 95 | ||||
-rw-r--r-- | lib/Language/Haskell/Stylish/GHC.hs | 103 | ||||
-rw-r--r-- | lib/Language/Haskell/Stylish/Module.hs | 283 | ||||
-rw-r--r-- | lib/Language/Haskell/Stylish/Ordering.hs | 61 | ||||
-rw-r--r-- | lib/Language/Haskell/Stylish/Parse.hs | 148 | ||||
-rw-r--r-- | lib/Language/Haskell/Stylish/Printer.hs | 458 | ||||
-rw-r--r-- | lib/Language/Haskell/Stylish/Step.hs | 14 | ||||
-rw-r--r-- | lib/Language/Haskell/Stylish/Step/Data.hs | 614 | ||||
-rw-r--r-- | lib/Language/Haskell/Stylish/Step/Imports.hs | 784 | ||||
-rw-r--r-- | lib/Language/Haskell/Stylish/Step/LanguagePragmas.hs | 112 | ||||
-rw-r--r-- | lib/Language/Haskell/Stylish/Step/ModuleHeader.hs | 222 | ||||
-rw-r--r-- | lib/Language/Haskell/Stylish/Step/SimpleAlign.hs | 224 | ||||
-rw-r--r-- | lib/Language/Haskell/Stylish/Step/Squash.hs | 71 | ||||
-rw-r--r-- | lib/Language/Haskell/Stylish/Step/UnicodeSyntax.hs | 57 | ||||
-rw-r--r-- | lib/Language/Haskell/Stylish/Util.hs | 126 |
18 files changed, 2638 insertions, 834 deletions
diff --git a/lib/Language/Haskell/Stylish.hs b/lib/Language/Haskell/Stylish.hs index c50db4d..a767889 100644 --- a/lib/Language/Haskell/Stylish.hs +++ b/lib/Language/Haskell/Stylish.hs @@ -91,14 +91,19 @@ unicodeSyntax = UnicodeSyntax.step -------------------------------------------------------------------------------- runStep :: Extensions -> Maybe FilePath -> Lines -> Step -> Either String Lines -runStep exts mfp ls step = - stepFilter step ls <$> parseModule exts mfp (unlines ls) - +runStep exts mfp ls = \case + Step _name step -> + step ls <$> parseModule exts mfp (unlines ls) -------------------------------------------------------------------------------- -runSteps :: Extensions -> Maybe FilePath -> [Step] -> Lines - -> Either String Lines -runSteps exts mfp steps ls = foldM (runStep exts mfp) ls steps +runSteps :: + Extensions + -> Maybe FilePath + -> [Step] + -> Lines + -> Either String Lines +runSteps exts mfp steps ls = + foldM (runStep exts mfp) ls steps newtype ConfigPath = ConfigPath { unConfigPath :: FilePath } diff --git a/lib/Language/Haskell/Stylish/Align.hs b/lib/Language/Haskell/Stylish/Align.hs index 1f28d7a..c8a092f 100644 --- a/lib/Language/Haskell/Stylish/Align.hs +++ b/lib/Language/Haskell/Stylish/Align.hs @@ -8,7 +8,7 @@ module Language.Haskell.Stylish.Align -------------------------------------------------------------------------------- import Data.List (nub) -import qualified Language.Haskell.Exts as H +import qualified SrcLoc as S -------------------------------------------------------------------------------- @@ -51,49 +51,48 @@ data Alignable a = Alignable , aRightLead :: !Int } deriving (Show) - -------------------------------------------------------------------------------- -- | Create changes that perform the alignment. + align - :: Maybe Int -- ^ Max columns - -> [Alignable H.SrcSpan] -- ^ Alignables - -> [Change String] -- ^ Changes performing the alignment. + :: Maybe Int -- ^ Max columns + -> [Alignable S.RealSrcSpan] -- ^ Alignables + -> [Change String] -- ^ Changes performing the alignment align _ [] = [] align maxColumns alignment - -- Do not make any change if we would go past the maximum number of columns. - | exceedsColumns (longestLeft + longestRight) = [] - | not (fixable alignment) = [] - | otherwise = map align' alignment + -- Do not make an changes if we would go past the maximum number of columns + | exceedsColumns (longestLeft + longestRight) = [] + | not (fixable alignment) = [] + | otherwise = map align' alignment where exceedsColumns i = case maxColumns of - Nothing -> False -- No number exceeds a maximum column count of - -- Nothing, because there is no limit to exceed. - Just c -> i > c + Nothing -> False + Just c -> i > c - -- The longest thing in the left column. - longestLeft = maximum $ map (H.srcSpanEndColumn . aLeft) alignment + -- The longest thing in the left column + longestLeft = maximum $ map (S.srcSpanEndCol . aLeft) alignment - -- The longest thing in the right column. + -- The longest thing in the right column longestRight = maximum - [ H.srcSpanEndColumn (aRight a) - H.srcSpanStartColumn (aRight a) - + aRightLead a - | a <- alignment - ] - - align' a = changeLine (H.srcSpanStartLine $ aContainer a) $ \str -> - let column = H.srcSpanEndColumn $ aLeft a - (pre, post) = splitAt column str - in [padRight longestLeft (trimRight pre) ++ trimLeft post] + [ S.srcSpanEndCol (aRight a) - S.srcSpanStartCol (aRight a) + + aRightLead a + | a <- alignment + ] + align' a = changeLine (S.srcSpanStartLine $ aContainer a) $ \str -> + let column = S.srcSpanEndCol $ aLeft a + (pre, post) = splitAt column str + in [padRight longestLeft (trimRight pre) ++ trimLeft post] -------------------------------------------------------------------------------- -- | Checks that all the alignables appear on a single line, and that they do -- not overlap. -fixable :: [Alignable H.SrcSpan] -> Bool + +fixable :: [Alignable S.RealSrcSpan] -> Bool fixable [] = False fixable [_] = False fixable fields = all singleLine containers && nonOverlapping containers where containers = map aContainer fields - singleLine s = H.srcSpanStartLine s == H.srcSpanEndLine s - nonOverlapping ss = length ss == length (nub $ map H.srcSpanStartLine ss) + singleLine s = S.srcSpanStartLine s == S.srcSpanEndLine s + nonOverlapping ss = length ss == length (nub $ map S.srcSpanStartLine ss) diff --git a/lib/Language/Haskell/Stylish/Block.hs b/lib/Language/Haskell/Stylish/Block.hs index 46111ee..9b07420 100644 --- a/lib/Language/Haskell/Stylish/Block.hs +++ b/lib/Language/Haskell/Stylish/Block.hs @@ -4,20 +4,17 @@ module Language.Haskell.Stylish.Block , LineBlock , SpanBlock , blockLength - , linesFromSrcSpan - , spanFromSrcSpan , moveBlock , adjacent , merge + , mergeAdjacent , overlapping , groupAdjacent ) where -------------------------------------------------------------------------------- -import Control.Arrow (arr, (&&&), (>>>)) -import qualified Data.IntSet as IS -import qualified Language.Haskell.Exts as H +import qualified Data.IntSet as IS -------------------------------------------------------------------------------- @@ -25,7 +22,8 @@ import qualified Language.Haskell.Exts as H data Block a = Block { blockStart :: Int , blockEnd :: Int - } deriving (Eq, Ord, Show) + } + deriving (Eq, Ord, Show) -------------------------------------------------------------------------------- @@ -40,21 +38,6 @@ type SpanBlock = Block Char blockLength :: Block a -> Int blockLength (Block start end) = end - start + 1 - --------------------------------------------------------------------------------- -linesFromSrcSpan :: H.SrcSpanInfo -> LineBlock -linesFromSrcSpan = H.srcInfoSpan >>> - H.srcSpanStartLine &&& H.srcSpanEndLine >>> - arr (uncurry Block) - - --------------------------------------------------------------------------------- -spanFromSrcSpan :: H.SrcSpanInfo -> SpanBlock -spanFromSrcSpan = H.srcInfoSpan >>> - H.srcSpanStartColumn &&& H.srcSpanEndColumn >>> - arr (uncurry Block) - - -------------------------------------------------------------------------------- moveBlock :: Int -> Block a -> Block a moveBlock offset (Block start end) = Block (start + offset) (end + offset) @@ -94,3 +77,8 @@ groupAdjacent = foldr go [] go (b1, x) gs = case break (adjacent b1 . fst) gs of (_, []) -> (b1, [x]) : gs (ys, ((b2, xs) : zs)) -> (merge b1 b2, x : xs) : (ys ++ zs) + +mergeAdjacent :: [Block a] -> [Block a] +mergeAdjacent (a : b : rest) | a `adjacent` b = merge a b : mergeAdjacent rest +mergeAdjacent (a : rest) = a : mergeAdjacent rest +mergeAdjacent [] = [] diff --git a/lib/Language/Haskell/Stylish/Config.hs b/lib/Language/Haskell/Stylish/Config.hs index 475a5e3..dde9d0d 100644 --- a/lib/Language/Haskell/Stylish/Config.hs +++ b/lib/Language/Haskell/Stylish/Config.hs @@ -1,16 +1,21 @@ -------------------------------------------------------------------------------- +{-# LANGUAGE BlockArguments #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE TemplateHaskell #-} module Language.Haskell.Stylish.Config ( Extensions , Config (..) + , ExitCodeBehavior (..) , defaultConfigBytes , configFilePath , loadConfig + , parseConfig ) where -------------------------------------------------------------------------------- +import Control.Applicative ((<|>)) import Control.Monad (forM, mzero) import Data.Aeson (FromJSON (..)) import qualified Data.Aeson as A @@ -41,6 +46,7 @@ import Language.Haskell.Stylish.Step import qualified Language.Haskell.Stylish.Step.Data as Data import qualified Language.Haskell.Stylish.Step.Imports as Imports import qualified Language.Haskell.Stylish.Step.LanguagePragmas as LanguagePragmas +import qualified Language.Haskell.Stylish.Step.ModuleHeader as ModuleHeader import qualified Language.Haskell.Stylish.Step.SimpleAlign as SimpleAlign import qualified Language.Haskell.Stylish.Step.Squash as Squash import qualified Language.Haskell.Stylish.Step.Tabs as Tabs @@ -60,8 +66,18 @@ data Config = Config , configLanguageExtensions :: [String] , configNewline :: IO.Newline , configCabal :: Bool + , configExitCode :: ExitCodeBehavior } +-------------------------------------------------------------------------------- +data ExitCodeBehavior + = NormalExitBehavior + | ErrorOnFormatExitBehavior + deriving (Eq) + +instance Show ExitCodeBehavior where + show NormalExitBehavior = "normal" + show ErrorOnFormatExitBehavior = "error_on_format" -------------------------------------------------------------------------------- instance FromJSON Config where @@ -126,6 +142,7 @@ parseConfig (A.Object o) = do <*> (o A..:? "language_extensions" A..!= []) <*> (o A..:? "newline" >>= parseEnum newlines IO.nativeNewline) <*> (o A..:? "cabal" A..!= True) + <*> (o A..:? "exit_code" >>= parseEnum exitCodes NormalExitBehavior) -- Then fill in the steps based on the partial config we already have stepValues <- o A..: "steps" :: A.Parser [A.Value] @@ -137,6 +154,10 @@ parseConfig (A.Object o) = do , ("lf", IO.LF) , ("crlf", IO.CRLF) ] + exitCodes = + [ ("normal", NormalExitBehavior) + , ("error_on_format", ErrorOnFormatExitBehavior) + ] parseConfig _ = mzero @@ -144,6 +165,7 @@ parseConfig _ = mzero catalog :: Map String (Config -> A.Object -> A.Parser Step) catalog = M.fromList [ ("imports", parseImports) + , ("module_header", parseModuleHeader) , ("records", parseRecords) , ("language_pragmas", parseLanguagePragmas) , ("simple_align", parseSimpleAlign) @@ -172,27 +194,54 @@ parseEnum strs _ (Just k) = case lookup k strs of Nothing -> fail $ "Unknown option: " ++ k ++ ", should be one of: " ++ intercalate ", " (map fst strs) +-------------------------------------------------------------------------------- +parseModuleHeader :: Config -> A.Object -> A.Parser Step +parseModuleHeader _ o = fmap ModuleHeader.step $ ModuleHeader.Config + <$> o A..:? "indent" A..!= ModuleHeader.indent def + <*> o A..:? "sort" A..!= ModuleHeader.sort def + <*> o A..:? "separate_lists" A..!= ModuleHeader.separateLists def + where + def = ModuleHeader.defaultConfig -------------------------------------------------------------------------------- parseSimpleAlign :: Config -> A.Object -> A.Parser Step parseSimpleAlign c o = SimpleAlign.step <$> pure (configColumns c) <*> (SimpleAlign.Config - <$> withDef SimpleAlign.cCases "cases" - <*> withDef SimpleAlign.cTopLevelPatterns "top_level_patterns" - <*> withDef SimpleAlign.cRecords "records") + <$> parseAlign "cases" SimpleAlign.cCases + <*> parseAlign "top_level_patterns" SimpleAlign.cTopLevelPatterns + <*> parseAlign "records" SimpleAlign.cRecords + <*> parseAlign "multi_way_if" SimpleAlign.cMultiWayIf) where - withDef f k = fromMaybe (f SimpleAlign.defaultConfig) <$> (o A..:? k) + parseAlign key f = + (o A..:? key >>= parseEnum aligns (f SimpleAlign.defaultConfig)) <|> + (boolToAlign <$> o A..: key) + aligns = + [ ("always", SimpleAlign.Always) + , ("adjacent", SimpleAlign.Adjacent) + , ("never", SimpleAlign.Never) + ] + boolToAlign True = SimpleAlign.Always + boolToAlign False = SimpleAlign.Never + -------------------------------------------------------------------------------- parseRecords :: Config -> A.Object -> A.Parser Step -parseRecords _ o = Data.step +parseRecords c o = Data.step <$> (Data.Config <$> (o A..: "equals" >>= parseIndent) <*> (o A..: "first_field" >>= parseIndent) <*> (o A..: "field_comment") - <*> (o A..: "deriving")) - + <*> (o A..: "deriving") + <*> (o A..:? "break_enums" A..!= False) + <*> (o A..:? "break_single_constructors" A..!= True) + <*> (o A..: "via" >>= parseIndent) + <*> (o A..:? "curried_context" A..!= False) + <*> (o A..:? "sort_deriving" A..!= True) + <*> pure configMaxColumns) + where + configMaxColumns = + maybe Data.NoMaxColumns Data.MaxColumns (configColumns c) parseIndent :: A.Value -> A.Parser Data.Indent parseIndent = A.withText "Indent" $ \t -> @@ -214,23 +263,21 @@ parseSquash _ _ = return Squash.step -------------------------------------------------------------------------------- parseImports :: Config -> A.Object -> A.Parser Step -parseImports config o = Imports.step - <$> pure (configColumns config) - <*> (Imports.Options - <$> (o A..:? "align" >>= parseEnum aligns (def Imports.importAlign)) - <*> (o A..:? "list_align" >>= parseEnum listAligns (def Imports.listAlign)) - <*> (o A..:? "pad_module_names" A..!= def Imports.padModuleNames) - <*> (o A..:? "long_list_align" - >>= parseEnum longListAligns (def Imports.longListAlign)) - -- Note that padding has to be at least 1. Default is 4. - <*> (o A..:? "empty_list_align" - >>= parseEnum emptyListAligns (def Imports.emptyListAlign)) - <*> o A..:? "list_padding" A..!= def Imports.listPadding - <*> o A..:? "separate_lists" A..!= def Imports.separateLists - <*> o A..:? "space_surround" A..!= def Imports.spaceSurround) +parseImports config o = fmap (Imports.step columns) $ Imports.Options + <$> (o A..:? "align" >>= parseEnum aligns (def Imports.importAlign)) + <*> (o A..:? "list_align" >>= parseEnum listAligns (def Imports.listAlign)) + <*> (o A..:? "pad_module_names" A..!= def Imports.padModuleNames) + <*> (o A..:? "long_list_align" >>= parseEnum longListAligns (def Imports.longListAlign)) + <*> (o A..:? "empty_list_align" >>= parseEnum emptyListAligns (def Imports.emptyListAlign)) + -- Note that padding has to be at least 1. Default is 4. + <*> (o A..:? "list_padding" >>= maybe (pure $ def Imports.listPadding) parseListPadding) + <*> o A..:? "separate_lists" A..!= def Imports.separateLists + <*> o A..:? "space_surround" A..!= def Imports.spaceSurround where def f = f Imports.defaultOptions + columns = configColumns config + aligns = [ ("global", Imports.Global) , ("file", Imports.File) @@ -243,6 +290,7 @@ parseImports config o = Imports.step , ("with_module_name", Imports.WithModuleName) , ("with_alias", Imports.WithAlias) , ("after_alias", Imports.AfterAlias) + , ("repeat", Imports.Repeat) ] longListAligns = @@ -257,6 +305,11 @@ parseImports config o = Imports.step , ("right_after", Imports.RightAfter) ] + parseListPadding = \case + A.String "module_name" -> pure Imports.LPModuleName + A.Number n | n >= 1 -> pure $ Imports.LPConstant (truncate n) + v -> A.typeMismatch "'module_name' or >=1 number" v + -------------------------------------------------------------------------------- parseLanguagePragmas :: Config -> A.Object -> A.Parser Step parseLanguagePragmas config o = LanguagePragmas.step diff --git a/lib/Language/Haskell/Stylish/GHC.hs b/lib/Language/Haskell/Stylish/GHC.hs new file mode 100644 index 0000000..c99d4bf --- /dev/null +++ b/lib/Language/Haskell/Stylish/GHC.hs @@ -0,0 +1,103 @@ +{-# LANGUAGE LambdaCase #-} +{-# OPTIONS_GHC -Wno-missing-fields #-} +-- | Utility functions for working with the GHC AST +module Language.Haskell.Stylish.GHC + ( dropAfterLocated + , dropBeforeLocated + , dropBeforeAndAfter + -- * Unsafe getters + , unsafeGetRealSrcSpan + , getEndLineUnsafe + , getStartLineUnsafe + -- * Standard settings + , baseDynFlags + -- * Positions + , unLocated + -- * Outputable operators + , showOutputable + , compareOutputable + ) where + +-------------------------------------------------------------------------------- +import Data.Function (on) + +-------------------------------------------------------------------------------- +import DynFlags (Settings (..), defaultDynFlags) +import qualified DynFlags as GHC +import FileSettings (FileSettings (..)) +import GHC.Fingerprint (fingerprint0) +import GHC.Platform +import GHC.Version (cProjectVersion) +import GhcNameVersion (GhcNameVersion (..)) +import qualified Outputable as GHC +import PlatformConstants (PlatformConstants (..)) +import SrcLoc (GenLocated (..), Located, RealLocated, + RealSrcSpan, SrcSpan (..), srcSpanEndLine, + srcSpanStartLine) +import ToolSettings (ToolSettings (..)) + +unsafeGetRealSrcSpan :: Located a -> RealSrcSpan +unsafeGetRealSrcSpan = \case + (L (RealSrcSpan s) _) -> s + _ -> error "could not get source code location" + +getStartLineUnsafe :: Located a -> Int +getStartLineUnsafe = srcSpanStartLine . unsafeGetRealSrcSpan + +getEndLineUnsafe :: Located a -> Int +getEndLineUnsafe = srcSpanEndLine . unsafeGetRealSrcSpan + +dropAfterLocated :: Maybe (Located a) -> [RealLocated b] -> [RealLocated b] +dropAfterLocated loc xs = case loc of + Just (L (RealSrcSpan rloc) _) -> + filter (\(L x _) -> srcSpanEndLine rloc >= srcSpanStartLine x) xs + _ -> xs + +dropBeforeLocated :: Maybe (Located a) -> [RealLocated b] -> [RealLocated b] +dropBeforeLocated loc xs = case loc of + Just (L (RealSrcSpan rloc) _) -> + filter (\(L x _) -> srcSpanStartLine rloc <= srcSpanEndLine x) xs + _ -> xs + +dropBeforeAndAfter :: Located a -> [RealLocated b] -> [RealLocated b] +dropBeforeAndAfter loc = dropBeforeLocated (Just loc) . dropAfterLocated (Just loc) + +baseDynFlags :: GHC.DynFlags +baseDynFlags = defaultDynFlags fakeSettings llvmConfig + where + fakeSettings = GHC.Settings + { sGhcNameVersion = GhcNameVersion "stylish-haskell" cProjectVersion + , sFileSettings = FileSettings {} + , sToolSettings = ToolSettings + { toolSettings_opt_P_fingerprint = fingerprint0, + toolSettings_pgm_F = "" + } + , sPlatformConstants = PlatformConstants + { pc_DYNAMIC_BY_DEFAULT = False + , pc_WORD_SIZE = 8 + } + , sTargetPlatform = Platform + { platformMini = PlatformMini + { platformMini_arch = ArchUnknown + , platformMini_os = OSUnknown + } + , platformWordSize = PW8 + , platformUnregisterised = True + , platformHasIdentDirective = False + , platformHasSubsectionsViaSymbols = False + , platformIsCrossCompiling = False + } + , sPlatformMisc = PlatformMisc {} + , sRawSettings = [] + } + + llvmConfig = GHC.LlvmConfig [] [] + +unLocated :: Located a -> a +unLocated (L _ a) = a + +showOutputable :: GHC.Outputable a => a -> String +showOutputable = GHC.showPpr baseDynFlags + +compareOutputable :: GHC.Outputable a => a -> a -> Ordering +compareOutputable = compare `on` showOutputable diff --git a/lib/Language/Haskell/Stylish/Module.hs b/lib/Language/Haskell/Stylish/Module.hs new file mode 100644 index 0000000..3dbebe0 --- /dev/null +++ b/lib/Language/Haskell/Stylish/Module.hs @@ -0,0 +1,283 @@ +{-# LANGUAGE BlockArguments #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE DeriveDataTypeable #-} +module Language.Haskell.Stylish.Module + ( -- * Data types + Module (..) + , ModuleHeader + , Import + , Decls + , Comments + , Lines + , makeModule + + -- * Getters + , moduleHeader + , moduleImports + , moduleImportGroups + , moduleDecls + , moduleComments + , moduleLanguagePragmas + , queryModule + , groupByLine + + -- * Imports + , canMergeImport + , mergeModuleImport + + -- * Annotations + , lookupAnnotation + + -- * Internal API getters + , rawComments + , rawImport + , rawModuleAnnotations + , rawModuleDecls + , rawModuleExports + , rawModuleHaddocks + , rawModuleName + ) where + +-------------------------------------------------------------------------------- +import Data.Function ((&), on) +import Data.Functor ((<&>)) +import Data.Generics (Typeable, everything, mkQ) +import Data.Maybe (mapMaybe) +import Data.Map (Map) +import qualified Data.Map as Map +import Data.List (nubBy, sort) +import Data.List.NonEmpty (NonEmpty (..), nonEmpty) +import Data.Text (Text) +import qualified Data.Text as T +import Data.Data (Data) + +-------------------------------------------------------------------------------- +import qualified ApiAnnotation as GHC +import qualified Lexer as GHC +import GHC.Hs (ImportDecl(..), ImportDeclQualifiedStyle(..)) +import qualified GHC.Hs as GHC +import GHC.Hs.Extension (GhcPs) +import GHC.Hs.Decls (LHsDecl) +import Outputable (Outputable) +import SrcLoc (GenLocated(..), RealLocated) +import SrcLoc (RealSrcSpan(..), SrcSpan(..)) +import SrcLoc (Located) +import qualified SrcLoc as GHC +import qualified Module as GHC + +-------------------------------------------------------------------------------- +import Language.Haskell.Stylish.GHC + +-------------------------------------------------------------------------------- +type Lines = [String] + + +-------------------------------------------------------------------------------- +-- | Concrete module type +data Module = Module + { parsedComments :: [GHC.RealLocated GHC.AnnotationComment] + , parsedAnnotations :: [(GHC.ApiAnnKey, [GHC.SrcSpan])] + , parsedAnnotSrcs :: Map RealSrcSpan [GHC.AnnKeywordId] + , parsedModule :: GHC.Located (GHC.HsModule GhcPs) + } deriving (Data) + +-- | Declarations in module +newtype Decls = Decls [LHsDecl GhcPs] + +-- | Import declaration in module +newtype Import = Import { unImport :: ImportDecl GhcPs } + deriving newtype (Outputable) + +-- | Returns true if the two import declarations can be merged +canMergeImport :: Import -> Import -> Bool +canMergeImport (Import i0) (Import i1) = and $ fmap (\f -> f i0 i1) + [ (==) `on` unLocated . ideclName + , (==) `on` ideclPkgQual + , (==) `on` ideclSource + , hasMergableQualified `on` ideclQualified + , (==) `on` ideclImplicit + , (==) `on` fmap unLocated . ideclAs + , (==) `on` fmap fst . ideclHiding -- same 'hiding' flags + ] + where + hasMergableQualified QualifiedPre QualifiedPost = True + hasMergableQualified QualifiedPost QualifiedPre = True + hasMergableQualified q0 q1 = q0 == q1 + +instance Eq Import where + i0 == i1 = canMergeImport i0 i1 && hasSameImports (unImport i0) (unImport i1) + where + hasSameImports = (==) `on` fmap snd . ideclHiding + +instance Ord Import where + compare (Import i0) (Import i1) = + ideclName i0 `compareOutputable` ideclName i1 <> + fmap showOutputable (ideclPkgQual i0) `compare` + fmap showOutputable (ideclPkgQual i1) <> + compareOutputable i0 i1 + +-- | Comments associated with module +newtype Comments = Comments [GHC.RealLocated GHC.AnnotationComment] + +-- | A module header is its name, exports and haddock docstring +data ModuleHeader = ModuleHeader + { name :: Maybe (GHC.Located GHC.ModuleName) + , exports :: Maybe (GHC.Located [GHC.LIE GhcPs]) + , haddocks :: Maybe GHC.LHsDocString + } + +-- | Create a module from GHC internal representations +makeModule :: GHC.PState -> GHC.Located (GHC.HsModule GHC.GhcPs) -> Module +makeModule pstate = Module comments annotations annotationMap + where + comments + = sort + . filterRealLocated + $ GHC.comment_q pstate ++ (GHC.annotations_comments pstate >>= snd) + + filterRealLocated = mapMaybe \case + GHC.L (GHC.RealSrcSpan s) e -> Just (GHC.L s e) + GHC.L (GHC.UnhelpfulSpan _) _ -> Nothing + + annotations + = GHC.annotations pstate + + annotationMap + = GHC.annotations pstate + & mapMaybe x + & Map.fromListWith (++) + + x = \case + ((RealSrcSpan rspan, annot), _) -> Just (rspan, [annot]) + _ -> Nothing + +-- | Get all declarations in module +moduleDecls :: Module -> Decls +moduleDecls = Decls . GHC.hsmodDecls . unLocated . parsedModule + +-- | Get comments in module +moduleComments :: Module -> Comments +moduleComments = Comments . parsedComments + +-- | Get module language pragmas +moduleLanguagePragmas :: Module -> [(RealSrcSpan, NonEmpty Text)] +moduleLanguagePragmas = mapMaybe toLanguagePragma . parsedComments + where + toLanguagePragma :: RealLocated GHC.AnnotationComment -> Maybe (RealSrcSpan, NonEmpty Text) + toLanguagePragma = \case + L pos (GHC.AnnBlockComment s) -> + Just (T.pack s) + >>= T.stripPrefix "{-#" + >>= T.stripSuffix "#-}" + <&> T.strip + <&> T.splitAt 8 -- length "LANGUAGE" + <&> fmap (T.splitOn ",") + <&> fmap (fmap T.strip) + <&> fmap (filter (not . T.null)) + >>= (\(T.toUpper . T.strip -> lang, xs) -> (lang,) <$> nonEmpty xs) + >>= (\(lang, nel) -> if lang == "LANGUAGE" then Just (pos, nel) else Nothing) + _ -> Nothing + +-- | Get module imports +moduleImports :: Module -> [Located Import] +moduleImports m + = parsedModule m + & unLocated + & GHC.hsmodImports + & fmap \(L pos i) -> L pos (Import i) + +-- | Get groups of imports from module +moduleImportGroups :: Module -> [NonEmpty (Located Import)] +moduleImportGroups = groupByLine unsafeGetRealSrcSpan . moduleImports + +-- The same logic as 'Language.Haskell.Stylish.Module.moduleImportGroups'. +groupByLine :: (a -> RealSrcSpan) -> [a] -> [NonEmpty a] +groupByLine f = go [] Nothing + where + go acc _ [] = ne acc + go acc mbCurrentLine (x:xs) = + let + lStart = GHC.srcSpanStartLine (f x) + lEnd = GHC.srcSpanEndLine (f x) in + case mbCurrentLine of + Just lPrevEnd | lPrevEnd + 1 < lStart + -> ne acc ++ go [x] (Just lEnd) xs + _ -> go (acc ++ [x]) (Just lEnd) xs + + ne [] = [] + ne (x : xs) = [x :| xs] + +-- | Merge two import declarations, keeping positions from the first +-- +-- As alluded, this highlights an issue with merging imports. The GHC +-- annotation comments aren't attached to any particular AST node. This +-- means that right now, we're manually reconstructing the attachment. By +-- merging two import declarations, we lose that mapping. +-- +-- It's not really a big deal if we consider that people don't usually +-- comment imports themselves. It _is_ however, systemic and it'd be better +-- if we processed comments beforehand and attached them to all AST nodes in +-- our own representation. +mergeModuleImport :: Located Import -> Located Import -> Located Import +mergeModuleImport (L p0 (Import i0)) (L _p1 (Import i1)) = + L p0 $ Import i0 { ideclHiding = newImportNames } + where + newImportNames = + case (ideclHiding i0, ideclHiding i1) of + (Just (b, L p imps0), Just (_, L _ imps1)) -> Just (b, L p (imps0 `merge` imps1)) + (Nothing, Nothing) -> Nothing + (Just x, Nothing) -> Just x + (Nothing, Just x) -> Just x + merge xs ys + = nubBy ((==) `on` showOutputable) (xs ++ ys) + +-- | Get module header +moduleHeader :: Module -> ModuleHeader +moduleHeader (Module _ _ _ (GHC.L _ m)) = ModuleHeader + { name = GHC.hsmodName m + , exports = GHC.hsmodExports m + , haddocks = GHC.hsmodHaddockModHeader m + } + +-- | Query for annotations associated with a 'SrcSpan' +lookupAnnotation :: SrcSpan -> Module -> [GHC.AnnKeywordId] +lookupAnnotation (RealSrcSpan rspan) m = Map.findWithDefault [] rspan (parsedAnnotSrcs m) +lookupAnnotation (UnhelpfulSpan _) _ = [] + +-- | Query the module AST using @f@ +queryModule :: Typeable a => (a -> [b]) -> Module -> [b] +queryModule f = everything (++) (mkQ [] f) . parsedModule + +-------------------------------------------------------------------------------- +-- | Getter for internal components in imports newtype +rawImport :: Import -> ImportDecl GhcPs +rawImport (Import i) = i + +-- | Getter for internal module name representation +rawModuleName :: ModuleHeader -> Maybe (GHC.Located GHC.ModuleName) +rawModuleName = name + +-- | Getter for internal module exports representation +rawModuleExports :: ModuleHeader -> Maybe (GHC.Located [GHC.LIE GhcPs]) +rawModuleExports = exports + +-- | Getter for internal module haddocks representation +rawModuleHaddocks :: ModuleHeader -> Maybe GHC.LHsDocString +rawModuleHaddocks = haddocks + +-- | Getter for internal module decls representation +rawModuleDecls :: Decls -> [LHsDecl GhcPs] +rawModuleDecls (Decls xs) = xs + +-- | Getter for internal module comments representation +rawComments :: Comments -> [GHC.RealLocated GHC.AnnotationComment] +rawComments (Comments xs) = xs + +-- | Getter for internal module annotation representation +rawModuleAnnotations :: Module -> [(GHC.ApiAnnKey, [GHC.SrcSpan])] +rawModuleAnnotations = parsedAnnotations diff --git a/lib/Language/Haskell/Stylish/Ordering.hs b/lib/Language/Haskell/Stylish/Ordering.hs new file mode 100644 index 0000000..1a05eb4 --- /dev/null +++ b/lib/Language/Haskell/Stylish/Ordering.hs @@ -0,0 +1,61 @@ +-------------------------------------------------------------------------------- +-- | There are a number of steps that sort items: 'Imports' and 'ModuleHeader', +-- and maybe more in the future. This module provides consistent sorting +-- utilities. +{-# LANGUAGE LambdaCase #-} +module Language.Haskell.Stylish.Ordering + ( compareLIE + , compareWrappedName + , unwrapName + ) where + + +-------------------------------------------------------------------------------- +import Data.Char (isUpper) +import Data.Ord (comparing) +import GHC.Hs +import RdrName (RdrName) +import SrcLoc (unLoc) + + +-------------------------------------------------------------------------------- +import Language.Haskell.Stylish.GHC (showOutputable) +import Outputable (Outputable) + + +-------------------------------------------------------------------------------- +-- | NOTE: Can we get rid off this by adding a properly sorting newtype around +-- 'RdrName'? +compareLIE :: LIE GhcPs -> LIE GhcPs -> Ordering +compareLIE = comparing $ ieKey . unLoc + where + -- | The implementation is a bit hacky to get proper sorting for input specs: + -- constructors first, followed by functions, and then operators. + ieKey :: IE GhcPs -> (Int, String) + ieKey = \case + IEVar _ n -> nameKey n + IEThingAbs _ n -> nameKey n + IEThingAll _ n -> nameKey n + IEThingWith _ n _ _ _ -> nameKey n + IEModuleContents _ n -> nameKey n + _ -> (2, "") + + +-------------------------------------------------------------------------------- +compareWrappedName :: IEWrappedName RdrName -> IEWrappedName RdrName -> Ordering +compareWrappedName = comparing nameKey + + +-------------------------------------------------------------------------------- +unwrapName :: IEWrappedName n -> n +unwrapName (IEName n) = unLoc n +unwrapName (IEPattern n) = unLoc n +unwrapName (IEType n) = unLoc n + + +-------------------------------------------------------------------------------- +nameKey :: Outputable name => name -> (Int, String) +nameKey n = case showOutputable n of + o@('(' : _) -> (2, o) + o@(o0 : _) | isUpper o0 -> (0, o) + o -> (1, o) diff --git a/lib/Language/Haskell/Stylish/Parse.hs b/lib/Language/Haskell/Stylish/Parse.hs index 01def63..b416a32 100644 --- a/lib/Language/Haskell/Stylish/Parse.hs +++ b/lib/Language/Haskell/Stylish/Parse.hs @@ -1,35 +1,39 @@ +{-# LANGUAGE LambdaCase #-} -------------------------------------------------------------------------------- module Language.Haskell.Stylish.Parse - ( parseModule - ) where + ( parseModule + ) where -------------------------------------------------------------------------------- -import Data.List (isPrefixOf, nub) +import Data.Function ((&)) import Data.Maybe (fromMaybe, listToMaybe) -import qualified Language.Haskell.Exts as H - +import System.IO.Unsafe (unsafePerformIO) -------------------------------------------------------------------------------- -import Language.Haskell.Stylish.Config -import Language.Haskell.Stylish.Step - +import Bag (bagToList) +import qualified DynFlags as GHC +import qualified ErrUtils as GHC +import FastString (mkFastString) +import qualified GHC.Hs as GHC +import qualified GHC.LanguageExtensions as GHC +import qualified HeaderInfo as GHC +import qualified HscTypes as GHC +import Lexer (ParseResult (..)) +import Lexer (mkPState, unP) +import qualified Lexer as GHC +import qualified Panic as GHC +import qualified Parser as GHC +import SrcLoc (mkRealSrcLoc) +import qualified SrcLoc as GHC +import StringBuffer (stringToStringBuffer) +import qualified StringBuffer as GHC -------------------------------------------------------------------------------- --- | Syntax-related language extensions are always enabled for parsing. Since we --- can't authoritatively know which extensions are enabled at compile-time, we --- should try not to throw errors when parsing any GHC-accepted code. -defaultExtensions :: [H.Extension] -defaultExtensions = map H.EnableExtension - [ H.GADTs - , H.HereDocuments - , H.KindSignatures - , H.NewQualifiedOperators - , H.PatternGuards - , H.StandaloneDeriving - , H.UnicodeSyntax - ] +import Language.Haskell.Stylish.GHC (baseDynFlags) +import Language.Haskell.Stylish.Module +type Extensions = [String] -------------------------------------------------------------------------------- -- | Filter out lines which use CPP macros @@ -42,15 +46,6 @@ unCpp = unlines . go False . lines nextMultiline = isCpp && not (null x) && last x == '\\' in (if isCpp then "" else x) : go nextMultiline xs - --------------------------------------------------------------------------------- --- | Remove shebang lines -unShebang :: String -> String -unShebang str = - let (shebangs, other) = break (not . ("#!" `isPrefixOf`)) (lines str) in - unlines $ map (const "") shebangs ++ other - - -------------------------------------------------------------------------------- -- | If the given string is prefixed with an UTF-8 Byte Order Mark, drop it -- because haskell-src-exts can't handle it. @@ -60,32 +55,69 @@ dropBom str = str -------------------------------------------------------------------------------- --- | Abstraction over HSE's parsing +-- | Abstraction over GHC lib's parsing parseModule :: Extensions -> Maybe FilePath -> String -> Either String Module -parseModule extraExts mfp string = do - -- Determine the extensions: those specified in the file and the extra ones - let noPrefixes = unShebang . dropBom $ string - extraExts' = map H.classifyExtension extraExts - (lang, fileExts) = fromMaybe (Nothing, []) $ H.readExtensions noPrefixes - exts = nub $ fileExts ++ extraExts' ++ defaultExtensions - - -- Parsing options... - fp = fromMaybe "<unknown>" mfp - mode = H.defaultParseMode - { H.extensions = exts - , H.fixities = Nothing - , H.baseLanguage = case lang of - Nothing -> H.baseLanguage H.defaultParseMode - Just l -> l - } - - -- Preprocessing - processed = if H.EnableExtension H.CPP `elem` exts - then unCpp noPrefixes - else noPrefixes - - case H.parseModuleWithComments mode processed of - H.ParseOk md -> return md - err -> Left $ - "Language.Haskell.Stylish.Parse.parseModule: could not parse " ++ - fp ++ ": " ++ show err +parseModule exts fp string = + parsePragmasIntoDynFlags baseDynFlags userExtensions filePath string >>= \dynFlags -> + dropBom string + & removeCpp dynFlags + & runParser dynFlags + & toModule dynFlags + where + toModule :: GHC.DynFlags -> GHC.ParseResult (GHC.Located (GHC.HsModule GHC.GhcPs)) -> Either String Module + toModule dynFlags res = case res of + POk ps m -> + Right (makeModule ps m) + PFailed failureState -> + let + withFileName x = maybe "" (<> ": ") fp <> x + in + Left . withFileName . unlines . getParserStateErrors dynFlags $ failureState + + removeCpp dynFlags s = + if GHC.xopt GHC.Cpp dynFlags then unCpp s + else s + + userExtensions = + fmap toLocatedExtensionFlag ("Haskell2010" : exts) -- FIXME: do we need `Haskell2010` here? + + toLocatedExtensionFlag flag + = "-X" <> flag + & GHC.L GHC.noSrcSpan + + getParserStateErrors dynFlags state + = GHC.getErrorMessages state dynFlags + & bagToList + & fmap (\errMsg -> show (GHC.errMsgSpan errMsg) <> ": " <> show errMsg) + + filePath = + fromMaybe "<interactive>" fp + + runParser :: GHC.DynFlags -> String -> GHC.ParseResult (GHC.Located (GHC.HsModule GHC.GhcPs)) + runParser flags str = + let + filename = mkFastString filePath + parseState = mkPState flags (stringToStringBuffer str) (mkRealSrcLoc filename 1 1) + in + unP GHC.parseModule parseState + +-- | Parse 'DynFlags' from the extra options +-- +-- /Note:/ this function would be IO, but we're not using any of the internal +-- features that constitute side effectful computation. So I think it's fine +-- if we run this to avoid changing the interface too much. +parsePragmasIntoDynFlags :: + GHC.DynFlags + -> [GHC.Located String] + -> FilePath + -> String + -> Either String GHC.DynFlags +{-# NOINLINE parsePragmasIntoDynFlags #-} +parsePragmasIntoDynFlags originalFlags extraOpts filepath str = unsafePerformIO $ catchErrors $ do + let opts = GHC.getOptions originalFlags (GHC.stringToStringBuffer str) filepath + (parsedFlags, _invalidFlags, _warnings) <- GHC.parseDynamicFilePragma originalFlags (opts <> extraOpts) + -- FIXME: have a look at 'leftovers' since it should be empty + return $ Right $ parsedFlags `GHC.gopt_set` GHC.Opt_KeepRawTokenStream + where + catchErrors act = GHC.handleGhcException reportErr (GHC.handleSourceError reportErr act) + reportErr e = return $ Left (show e) diff --git a/lib/Language/Haskell/Stylish/Printer.hs b/lib/Language/Haskell/Stylish/Printer.hs new file mode 100644 index 0000000..a7ddf5e --- /dev/null +++ b/lib/Language/Haskell/Stylish/Printer.hs @@ -0,0 +1,458 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE BlockArguments #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE DoAndIfThenElse #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE LambdaCase #-} +module Language.Haskell.Stylish.Printer + ( Printer(..) + , PrinterConfig(..) + , PrinterState(..) + + -- * Alias + , P + + -- * Functions to use the printer + , runPrinter + , runPrinter_ + + -- ** Combinators + , comma + , dot + , getAnnot + , getCurrentLine + , getCurrentLineLength + , getDocstrPrev + , newline + , parenthesize + , peekNextCommentPos + , prefix + , putComment + , putEolComment + , putOutputable + , putAllSpanComments + , putCond + , putType + , putRdrName + , putText + , removeCommentTo + , removeCommentToEnd + , removeLineComment + , sep + , groupAttachedComments + , space + , spaces + , suffix + , pad + + -- ** Advanced combinators + , withColumns + , modifyCurrentLine + , wrapping + ) where + +-------------------------------------------------------------------------------- +import Prelude hiding (lines) + +-------------------------------------------------------------------------------- +import ApiAnnotation (AnnKeywordId(..), AnnotationComment(..)) +import GHC.Hs.Extension (GhcPs, NoExtField(..)) +import GHC.Hs.Types (HsType(..)) +import Module (ModuleName, moduleNameString) +import RdrName (RdrName(..)) +import SrcLoc (GenLocated(..), RealLocated) +import SrcLoc (Located, SrcSpan(..)) +import SrcLoc (srcSpanStartLine, srcSpanEndLine) +import Outputable (Outputable) + +-------------------------------------------------------------------------------- +import Control.Monad (forM_, replicateM_) +import Control.Monad.Reader (MonadReader, ReaderT(..), asks, local) +import Control.Monad.State (MonadState, State) +import Control.Monad.State (runState) +import Control.Monad.State (get, gets, modify, put) +import Data.Foldable (find) +import Data.Functor ((<&>)) +import Data.List (delete, isPrefixOf) +import Data.List.NonEmpty (NonEmpty(..)) + +-------------------------------------------------------------------------------- +import Language.Haskell.Stylish.Module (Module, Lines, lookupAnnotation) +import Language.Haskell.Stylish.GHC (showOutputable, unLocated) + +-- | Shorthand for 'Printer' monad +type P = Printer + +-- | Printer that keeps state of file +newtype Printer a = Printer (ReaderT PrinterConfig (State PrinterState) a) + deriving (Applicative, Functor, Monad, MonadReader PrinterConfig, MonadState PrinterState) + +-- | Configuration for printer, currently empty +data PrinterConfig = PrinterConfig + { columns :: !(Maybe Int) + } + +-- | State of printer +data PrinterState = PrinterState + { lines :: !Lines + , linePos :: !Int + , currentLine :: !String + , pendingComments :: ![RealLocated AnnotationComment] + , parsedModule :: !Module + } + +-- | Run printer to get printed lines out of module as well as return value of monad +runPrinter :: PrinterConfig -> [RealLocated AnnotationComment] -> Module -> Printer a -> (a, Lines) +runPrinter cfg comments m (Printer printer) = + let + (a, PrinterState parsedLines _ startedLine _ _) = runReaderT printer cfg `runState` PrinterState [] 0 "" comments m + in + (a, parsedLines <> if startedLine == [] then [] else [startedLine]) + +-- | Run printer to get printed lines only +runPrinter_ :: PrinterConfig -> [RealLocated AnnotationComment] -> Module -> Printer a -> Lines +runPrinter_ cfg comments m printer = snd (runPrinter cfg comments m printer) + +-- | Print text +putText :: String -> P () +putText txt = do + l <- gets currentLine + modify \s -> s { currentLine = l <> txt } + +-- | Check condition post action, and use fallback if false +putCond :: (PrinterState -> Bool) -> P b -> P b -> P b +putCond p action fallback = do + prevState <- get + res <- action + currState <- get + if p currState then pure res + else put prevState >> fallback + +-- | Print an 'Outputable' +putOutputable :: Outputable a => a -> P () +putOutputable = putText . showOutputable + +-- | Put all comments that has positions within 'SrcSpan' and separate by +-- passed @P ()@ +putAllSpanComments :: P () -> SrcSpan -> P () +putAllSpanComments suff = \case + UnhelpfulSpan _ -> pure () + RealSrcSpan rspan -> do + cmts <- removeComments \(L rloc _) -> + srcSpanStartLine rloc >= srcSpanStartLine rspan && + srcSpanEndLine rloc <= srcSpanEndLine rspan + + forM_ cmts (\c -> putComment c >> suff) + +-- | Print any comment +putComment :: AnnotationComment -> P () +putComment = \case + AnnLineComment s -> putText s + AnnDocCommentNext s -> putText s + AnnDocCommentPrev s -> putText s + AnnDocCommentNamed s -> putText s + AnnDocSection _ s -> putText s + AnnDocOptions s -> putText s + AnnBlockComment s -> putText s + +-- | Given the current start line of 'SrcSpan', remove and put EOL comment for same line +putEolComment :: SrcSpan -> P () +putEolComment = \case + RealSrcSpan rspan -> do + cmt <- removeComment \case + L rloc (AnnLineComment s) -> + and + [ srcSpanStartLine rspan == srcSpanStartLine rloc + , not ("-- ^" `isPrefixOf` s) + , not ("-- |" `isPrefixOf` s) + ] + _ -> False + forM_ cmt (\c -> space >> putComment c) + UnhelpfulSpan _ -> pure () + +-- | Print a 'RdrName' +putRdrName :: Located RdrName -> P () +putRdrName (L pos n) = case n of + Unqual name -> do + annots <- getAnnot pos + if AnnOpenP `elem` annots then do + putText "(" + putText (showOutputable name) + putText ")" + else if AnnBackquote `elem` annots then do + putText "`" + putText (showOutputable name) + putText "`" + else if AnnSimpleQuote `elem` annots then do + putText "'" + putText (showOutputable name) + else + putText (showOutputable name) + Qual modulePrefix name -> + putModuleName modulePrefix >> dot >> putText (showOutputable name) + Orig _ name -> + putText (showOutputable name) + Exact name -> + putText (showOutputable name) + +-- | Print module name +putModuleName :: ModuleName -> P () +putModuleName = putText . moduleNameString + +-- | Print type +putType :: Located (HsType GhcPs) -> P () +putType ltp = case unLocated ltp of + HsFunTy NoExtField argTp funTp -> do + putOutputable argTp + space + putText "->" + space + putType funTp + HsAppTy NoExtField t1 t2 -> + putType t1 >> space >> putType t2 + HsExplicitListTy NoExtField _ xs -> do + putText "'[" + sep + (comma >> space) + (fmap putType xs) + putText "]" + HsExplicitTupleTy NoExtField xs -> do + putText "'(" + sep + (comma >> space) + (fmap putType xs) + putText ")" + HsOpTy NoExtField lhs op rhs -> do + putType lhs + space + putRdrName op + space + putType rhs + HsTyVar NoExtField _ rdrName -> + putRdrName rdrName + HsTyLit _ tp -> + putOutputable tp + HsParTy _ tp -> do + putText "(" + putType tp + putText ")" + HsTupleTy NoExtField _ xs -> do + putText "(" + sep + (comma >> space) + (fmap putType xs) + putText ")" + HsForAllTy NoExtField _ _ _ -> + putOutputable ltp + HsQualTy NoExtField _ _ -> + putOutputable ltp + HsAppKindTy _ _ _ -> + putOutputable ltp + HsListTy _ _ -> + putOutputable ltp + HsSumTy _ _ -> + putOutputable ltp + HsIParamTy _ _ _ -> + putOutputable ltp + HsKindSig _ _ _ -> + putOutputable ltp + HsStarTy _ _ -> + putOutputable ltp + HsSpliceTy _ _ -> + putOutputable ltp + HsDocTy _ _ _ -> + putOutputable ltp + HsBangTy _ _ _ -> + putOutputable ltp + HsRecTy _ _ -> + putOutputable ltp + HsWildCardTy _ -> + putOutputable ltp + XHsType _ -> + putOutputable ltp + +-- | Get a docstring on the start line of 'SrcSpan' that is a @-- ^@ comment +getDocstrPrev :: SrcSpan -> P (Maybe AnnotationComment) +getDocstrPrev = \case + UnhelpfulSpan _ -> pure Nothing + RealSrcSpan rspan -> do + removeComment \case + L rloc (AnnLineComment s) -> + and + [ srcSpanStartLine rspan == srcSpanStartLine rloc + , "-- ^" `isPrefixOf` s + ] + _ -> False + +-- | Print a newline +newline :: P () +newline = do + l <- gets currentLine + modify \s -> s { currentLine = "", linePos = 0, lines = lines s <> [l] } + +-- | Print a space +space :: P () +space = putText " " + +-- | Print a number of spaces +spaces :: Int -> P () +spaces i = replicateM_ i space + +-- | Print a dot +dot :: P () +dot = putText "." + +-- | Print a comma +comma :: P () +comma = putText "," + +-- | Add parens around a printed action +parenthesize :: P a -> P a +parenthesize action = putText "(" *> action <* putText ")" + +-- | Add separator between each element of the given printers +sep :: P a -> [P a] -> P () +sep _ [] = pure () +sep s (first : rest) = first >> forM_ rest ((>>) s) + +-- | Prefix a printer with another one +prefix :: P a -> P b -> P b +prefix pa pb = pa >> pb + +-- | Suffix a printer with another one +suffix :: P a -> P b -> P a +suffix pa pb = pb >> pa + +-- | Indent to a given number of spaces. If the current line already exceeds +-- that number in length, nothing happens. +pad :: Int -> P () +pad n = do + len <- length <$> getCurrentLine + spaces $ n - len + +-- | Gets comment on supplied 'line' and removes it from the state +removeLineComment :: Int -> P (Maybe AnnotationComment) +removeLineComment line = + removeComment (\(L rloc _) -> srcSpanStartLine rloc == line) + +-- | Removes comments from the state up to start line of 'SrcSpan' and returns +-- the ones that were removed +removeCommentTo :: SrcSpan -> P [AnnotationComment] +removeCommentTo = \case + UnhelpfulSpan _ -> pure [] + RealSrcSpan rspan -> removeCommentTo' (srcSpanStartLine rspan) + +-- | Removes comments from the state up to end line of 'SrcSpan' and returns +-- the ones that were removed +removeCommentToEnd :: SrcSpan -> P [AnnotationComment] +removeCommentToEnd = \case + UnhelpfulSpan _ -> pure [] + RealSrcSpan rspan -> removeCommentTo' (srcSpanEndLine rspan) + +-- | Removes comments to the line number given and returns the ones removed +removeCommentTo' :: Int -> P [AnnotationComment] +removeCommentTo' line = + removeComment (\(L rloc _) -> srcSpanStartLine rloc < line) >>= \case + Nothing -> pure [] + Just c -> do + rest <- removeCommentTo' line + pure (c : rest) + +-- | Removes comments from the state while given predicate 'p' is true +removeComments :: (RealLocated AnnotationComment -> Bool) -> P [AnnotationComment] +removeComments p = + removeComment p >>= \case + Just c -> do + rest <- removeComments p + pure (c : rest) + Nothing -> pure [] + +-- | Remove a comment from the state given predicate 'p' +removeComment :: (RealLocated AnnotationComment -> Bool) -> P (Maybe AnnotationComment) +removeComment p = do + comments <- gets pendingComments + + let + foundComment = + find p comments + + newPendingComments = + maybe comments (`delete` comments) foundComment + + modify \s -> s { pendingComments = newPendingComments } + pure $ fmap (\(L _ c) -> c) foundComment + +-- | Get all annotations for 'SrcSpan' +getAnnot :: SrcSpan -> P [AnnKeywordId] +getAnnot spn = gets (lookupAnnotation spn . parsedModule) + +-- | Get current line +getCurrentLine :: P String +getCurrentLine = gets currentLine + +-- | Get current line length +getCurrentLineLength :: P Int +getCurrentLineLength = fmap length getCurrentLine + +-- | Peek at the next comment in the state +peekNextCommentPos :: P (Maybe SrcSpan) +peekNextCommentPos = do + gets pendingComments <&> \case + (L next _ : _) -> Just (RealSrcSpan next) + [] -> Nothing + +-- | Get attached comments belonging to '[Located a]' given +groupAttachedComments :: [Located a] -> P [([AnnotationComment], NonEmpty (Located a))] +groupAttachedComments = go + where + go :: [Located a] -> P [([AnnotationComment], NonEmpty (Located a))] + go (L rspan x : xs) = do + comments <- removeCommentTo rspan + nextGroupStartM <- peekNextCommentPos + + let + sameGroupOf = maybe xs \nextGroupStart -> + takeWhile (\(L p _)-> p < nextGroupStart) xs + + restOf = maybe [] \nextGroupStart -> + dropWhile (\(L p _) -> p <= nextGroupStart) xs + + restGroups <- go (restOf nextGroupStartM) + pure $ (comments, L rspan x :| sameGroupOf nextGroupStartM) : restGroups + + go _ = pure [] + +modifyCurrentLine :: (String -> String) -> P () +modifyCurrentLine f = do + s0 <- get + put s0 {currentLine = f $ currentLine s0} + +wrapping + :: P a -- ^ First printer to run + -> P a -- ^ Printer to run if first printer violates max columns + -> P a -- ^ Result of either the first or the second printer +wrapping p1 p2 = do + maxCols <- asks columns + case maxCols of + -- No wrapping + Nothing -> p1 + Just c -> do + s0 <- get + x <- p1 + s1 <- get + if length (currentLine s1) <= c + -- No need to wrap + then pure x + else do + put s0 + y <- p2 + s2 <- get + if length (currentLine s1) == length (currentLine s2) + -- Wrapping didn't help! + then put s1 >> pure x + -- Wrapped + else pure y + +withColumns :: Maybe Int -> P a -> P a +withColumns c = local $ \pc -> pc {columns = c} diff --git a/lib/Language/Haskell/Stylish/Step.hs b/lib/Language/Haskell/Stylish/Step.hs index e5f3424..c2cfc70 100644 --- a/lib/Language/Haskell/Stylish/Step.hs +++ b/lib/Language/Haskell/Stylish/Step.hs @@ -1,24 +1,13 @@ -------------------------------------------------------------------------------- module Language.Haskell.Stylish.Step ( Lines - , Module , Step (..) , makeStep ) where -------------------------------------------------------------------------------- -import qualified Language.Haskell.Exts as H - - --------------------------------------------------------------------------------- -type Lines = [String] - - --------------------------------------------------------------------------------- --- | Concrete module type -type Module = (H.Module H.SrcSpanInfo, [H.Comment]) - +import Language.Haskell.Stylish.Module -------------------------------------------------------------------------------- data Step = Step @@ -26,7 +15,6 @@ data Step = Step , stepFilter :: Lines -> Module -> Lines } - -------------------------------------------------------------------------------- makeStep :: String -> (Lines -> Module -> Lines) -> Step makeStep = Step diff --git a/lib/Language/Haskell/Stylish/Step/Data.hs b/lib/Language/Haskell/Stylish/Step/Data.hs index 1f7732b..77d12a0 100644 --- a/lib/Language/Haskell/Stylish/Step/Data.hs +++ b/lib/Language/Haskell/Stylish/Step/Data.hs @@ -1,126 +1,546 @@ +{-# LANGUAGE BlockArguments #-} +{-# LANGUAGE DoAndIfThenElse #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE RecordWildCards #-} +module Language.Haskell.Stylish.Step.Data + ( Config(..) + , defaultConfig -module Language.Haskell.Stylish.Step.Data where + , Indent(..) + , MaxColumns(..) + , step + ) where -import Data.List (find, intercalate) -import Data.Maybe (fromMaybe, maybeToList) -import qualified Language.Haskell.Exts as H -import Language.Haskell.Exts.Comments +-------------------------------------------------------------------------------- +import Prelude hiding (init) + +-------------------------------------------------------------------------------- +import Control.Monad (forM_, unless, when) +import Data.Function ((&)) +import Data.Functor ((<&>)) +import Data.List (sortBy) +import Data.Maybe (listToMaybe) + +-------------------------------------------------------------------------------- +import ApiAnnotation (AnnotationComment) +import BasicTypes (LexicalFixity (..)) +import GHC.Hs.Decls (ConDecl (..), + DerivStrategy (..), + HsDataDefn (..), HsDecl (..), + HsDerivingClause (..), + NewOrData (..), + TyClDecl (..)) +import GHC.Hs.Extension (GhcPs, NoExtField (..), + noExtCon) +import GHC.Hs.Types (ConDeclField (..), + ForallVisFlag (..), + HsConDetails (..), HsContext, + HsImplicitBndrs (..), + HsTyVarBndr (..), + HsType (..), LHsQTyVars (..)) +import RdrName (RdrName) +import SrcLoc (GenLocated (..), Located, + RealLocated) + +-------------------------------------------------------------------------------- import Language.Haskell.Stylish.Block import Language.Haskell.Stylish.Editor +import Language.Haskell.Stylish.GHC +import Language.Haskell.Stylish.Module +import Language.Haskell.Stylish.Printer import Language.Haskell.Stylish.Step -import Language.Haskell.Stylish.Util -import Prelude hiding (init) data Indent = SameLine | Indent !Int - deriving (Show) + deriving (Show, Eq) + +data MaxColumns + = MaxColumns !Int + | NoMaxColumns + deriving (Show, Eq) data Config = Config - { cEquals :: !Indent + { cEquals :: !Indent -- ^ Indent between type constructor and @=@ sign (measured from column 0) - , cFirstField :: !Indent + , cFirstField :: !Indent -- ^ Indent between data constructor and @{@ line (measured from column with data constructor name) - , cFieldComment :: !Int + , cFieldComment :: !Int -- ^ Indent between column with @{@ and start of field line comment (this line has @cFieldComment = 2@) - , cDeriving :: !Int + , cDeriving :: !Int -- ^ Indent before @deriving@ lines (measured from column 0) + , cBreakEnums :: !Bool + -- ^ Break enums by newlines and follow the above rules + , cBreakSingleConstructors :: !Bool + -- ^ Break single constructors when enabled, e.g. @Indent 2@ will not cause newline after @=@ + , cVia :: !Indent + -- ^ Indentation between @via@ clause and start of deriving column start + , cCurriedContext :: !Bool + -- ^ If true, use curried context. E.g: @allValues :: Enum a => Bounded a => Proxy a -> [a]@ + , cSortDeriving :: !Bool + -- ^ If true, will sort type classes in a @deriving@ list. + , cMaxColumns :: !MaxColumns } deriving (Show) -datas :: H.Module l -> [H.Decl l] -datas (H.Module _ _ _ _ decls) = decls -datas _ = [] - -type ChangeLine = Change String +-- | TODO: pass in MaxColumns? +defaultConfig :: Config +defaultConfig = Config + { cEquals = Indent 4 + , cFirstField = Indent 4 + , cFieldComment = 2 + , cDeriving = 4 + , cBreakEnums = True + , cBreakSingleConstructors = False + , cVia = Indent 4 + , cSortDeriving = True + , cMaxColumns = NoMaxColumns + , cCurriedContext = False + } step :: Config -> Step -step cfg = makeStep "Data" (step' cfg) - -step' :: Config -> Lines -> Module -> Lines -step' cfg ls (module', allComments) = applyChanges changes ls +step cfg = makeStep "Data" \ls m -> applyChanges (changes m) ls where - datas' = datas $ fmap linesFromSrcSpan module' - changes = datas' >>= maybeToList . changeDecl allComments cfg + changes :: Module -> [ChangeLine] + changes m = fmap (formatDataDecl cfg m) (dataDecls m) + + dataDecls :: Module -> [Located DataDecl] + dataDecls = queryModule \case + L pos (TyClD _ (DataDecl _ name tvars fixity defn)) -> pure . L pos $ MkDataDecl + { dataDeclName = name + , dataTypeVars = tvars + , dataDefn = defn + , dataFixity = fixity + } + _ -> [] + +type ChangeLine = Change String -findCommentOnLine :: LineBlock -> [Comment] -> Maybe Comment -findCommentOnLine lb = find commentOnLine +formatDataDecl :: Config -> Module -> Located DataDecl -> ChangeLine +formatDataDecl cfg@Config{..} m ldecl@(L declPos decl) = + change originalDeclBlock (const printedDecl) where - commentOnLine (Comment _ (H.SrcSpan _ start _ end _) _) = - blockStart lb == start && blockEnd lb == end + relevantComments :: [RealLocated AnnotationComment] + relevantComments + = moduleComments m + & rawComments + & dropBeforeAndAfter ldecl + + defn = dataDefn decl + + originalDeclBlock = + Block (getStartLineUnsafe ldecl) (getEndLineUnsafe ldecl) + + printerConfig = PrinterConfig + { columns = case cMaxColumns of + NoMaxColumns -> Nothing + MaxColumns n -> Just n + } + + printedDecl = runPrinter_ printerConfig relevantComments m do + putText (newOrData decl) + space + putName decl + + when (isGADT decl) (space >> putText "where") + + when (hasConstructors decl) do + breakLineBeforeEq <- case (cEquals, cFirstField) of + (_, Indent x) | isEnum decl && cBreakEnums -> do + putEolComment declPos + newline >> spaces x + pure True + (_, _) | not (isNewtype decl) && singleConstructor decl && not cBreakSingleConstructors -> + False <$ space + (Indent x, _) + | isEnum decl && not cBreakEnums -> False <$ space + | otherwise -> do + putEolComment declPos + newline >> spaces x + pure True + (SameLine, _) -> False <$ space + + lineLengthAfterEq <- fmap (+2) getCurrentLineLength + + if isEnum decl && not cBreakEnums then + putText "=" >> space >> putUnbrokenEnum cfg decl + else if isNewtype decl then + putText "=" >> space >> forM_ (dd_cons defn) (putNewtypeConstructor cfg) + else + case dd_cons defn of + [] -> pure () + lcon@(L pos _) : consRest -> do + when breakLineBeforeEq do + removeCommentTo pos >>= mapM_ \c -> putComment c >> consIndent lineLengthAfterEq + + unless + (isGADT decl) + (putText "=" >> space) + + putConstructor cfg lineLengthAfterEq lcon + forM_ consRest \con@(L conPos _) -> do + unless (cFirstField == SameLine) do + removeCommentTo conPos >>= mapM_ \c -> consIndent lineLengthAfterEq >> putComment c + consIndent lineLengthAfterEq + + unless + (isGADT decl) + (putText "|" >> space) + + putConstructor cfg lineLengthAfterEq con + putEolComment conPos + + when (hasDeriving decl) do + if isEnum decl && not cBreakEnums then + space + else do + removeCommentTo (defn & dd_derivs & \(L pos _) -> pos) >>= + mapM_ \c -> newline >> spaces cDeriving >> putComment c + newline + spaces cDeriving + + sep (newline >> spaces cDeriving) $ defn & dd_derivs & \(L pos ds) -> ds <&> \d -> do + putAllSpanComments (newline >> spaces cDeriving) pos + putDeriving cfg d + + consIndent eqIndent = newline >> case (cEquals, cFirstField) of + (SameLine, SameLine) -> spaces (eqIndent - 2) + (SameLine, Indent y) -> spaces (eqIndent + y - 4) + (Indent x, Indent _) -> spaces x + (Indent x, SameLine) -> spaces x + +data DataDecl = MkDataDecl + { dataDeclName :: Located RdrName + , dataTypeVars :: LHsQTyVars GhcPs + , dataDefn :: HsDataDefn GhcPs + , dataFixity :: LexicalFixity + } + +putDeriving :: Config -> Located (HsDerivingClause GhcPs) -> P () +putDeriving Config{..} (L pos clause) = do + putText "deriving" + + forM_ (deriv_clause_strategy clause) \case + L _ StockStrategy -> space >> putText "stock" + L _ AnyclassStrategy -> space >> putText "anyclass" + L _ NewtypeStrategy -> space >> putText "newtype" + L _ (ViaStrategy _) -> pure () + + putCond + withinColumns + oneLinePrint + multilinePrint + + forM_ (deriv_clause_strategy clause) \case + L _ (ViaStrategy tp) -> do + case cVia of + SameLine -> space + Indent x -> newline >> spaces (x + cDeriving) + + putText "via" + space + putType (getType tp) + _ -> pure () + + putEolComment pos -findCommentBelowLine :: LineBlock -> [Comment] -> Maybe Comment -findCommentBelowLine lb = find commentOnLine where - commentOnLine (Comment _ (H.SrcSpan _ start _ end _) _) = - blockStart lb == start - 1 && blockEnd lb == end - 1 + getType = \case + HsIB _ tp -> tp + XHsImplicitBndrs x -> noExtCon x + + withinColumns PrinterState{currentLine} = + case cMaxColumns of + MaxColumns maxCols -> length currentLine <= maxCols + NoMaxColumns -> True + + oneLinePrint = do + space + putText "(" + sep + (comma >> space) + (fmap putOutputable tys) + putText ")" + + multilinePrint = do + newline + spaces indentation + putText "(" + + forM_ headTy \t -> + space >> putOutputable t + + forM_ tailTy \t -> do + newline + spaces indentation + comma + space + putOutputable t + + newline + spaces indentation + putText ")" + + indentation = + cDeriving + case cFirstField of + Indent x -> x + SameLine -> 0 + + tys + = clause + & deriv_clause_tys + & unLocated + & (if cSortDeriving then sortBy compareOutputable else id) + & fmap hsib_body + + headTy = + listToMaybe tys + + tailTy = + drop 1 tys + +putUnbrokenEnum :: Config -> DataDecl -> P () +putUnbrokenEnum cfg decl = + sep + (space >> putText "|" >> space) + (fmap (putConstructor cfg 0) . dd_cons . dataDefn $ decl) + +putName :: DataDecl -> P () +putName decl@MkDataDecl{..} = + if isInfix decl then do + forM_ firstTvar (\t -> putOutputable t >> space) + putRdrName dataDeclName + space + forM_ secondTvar putOutputable + else do + putRdrName dataDeclName + forM_ (hsq_explicit dataTypeVars) (\t -> space >> putOutputable t) -commentsWithin :: LineBlock -> [Comment] -> [Comment] -commentsWithin lb = filter within where - within (Comment _ (H.SrcSpan _ start _ end _) _) = - start >= blockStart lb && end <= blockEnd lb - -changeDecl :: [Comment] -> Config -> H.Decl LineBlock -> Maybe ChangeLine -changeDecl _ _ (H.DataDecl _ (H.DataType _) Nothing _ [] _) = Nothing -changeDecl allComments cfg@Config{..} (H.DataDecl block (H.DataType _) Nothing dhead decls derivings) - | hasRecordFields = Just $ change block (const $ concat newLines) - | otherwise = Nothing + firstTvar :: Maybe (Located (HsTyVarBndr GhcPs)) + firstTvar + = dataTypeVars + & hsq_explicit + & listToMaybe + + secondTvar :: Maybe (Located (HsTyVarBndr GhcPs)) + secondTvar + = dataTypeVars + & hsq_explicit + & drop 1 + & listToMaybe + +putConstructor :: Config -> Int -> Located (ConDecl GhcPs) -> P () +putConstructor cfg consIndent (L _ cons) = case cons of + ConDeclGADT{..} -> do + -- Put argument to constructor first: + case con_args of + PrefixCon _ -> do + sep + (comma >> space) + (fmap putRdrName con_names) + + InfixCon arg1 arg2 -> do + putType arg1 + space + forM_ con_names putRdrName + space + putType arg2 + RecCon _ -> + error . mconcat $ + [ "Language.Haskell.Stylish.Step.Data.putConstructor: " + , "encountered a GADT with record constructors, not supported yet" + ] + + -- Put type of constructor: + space + putText "::" + space + + when (unLocated con_forall) do + putText "forall" + space + sep space (fmap putOutputable $ hsq_explicit con_qvars) + dot + space + + forM_ con_mb_cxt (putContext cfg . unLocated) + putType con_res_ty + + XConDecl x -> + noExtCon x + ConDeclH98{..} -> + case con_args of + InfixCon arg1 arg2 -> do + putType arg1 + space + putRdrName con_name + space + putType arg2 + PrefixCon xs -> do + putRdrName con_name + unless (null xs) space + sep space (fmap putOutputable xs) + RecCon (L recPos (L posFirst firstArg : args)) -> do + putRdrName con_name + skipToBrace + bracePos <- getCurrentLineLength + putText "{" + let fieldPos = bracePos + 2 + space + + -- Unless everything's configured to be on the same line, put pending + -- comments + unless (cFirstField cfg == SameLine) do + removeCommentTo posFirst >>= mapM_ \c -> putComment c >> sepDecl bracePos + + -- Put first decl field + pad fieldPos >> putConDeclField cfg firstArg + unless (cFirstField cfg == SameLine) (putEolComment posFirst) + + -- Put tail decl fields + forM_ args \(L pos arg) -> do + sepDecl bracePos + removeCommentTo pos >>= mapM_ \c -> + spaces (cFieldComment cfg) >> putComment c >> sepDecl bracePos + comma + space + putConDeclField cfg arg + putEolComment pos + + -- Print docstr after final field + removeCommentToEnd recPos >>= mapM_ \c -> + sepDecl bracePos >> spaces (cFieldComment cfg) >> putComment c + + -- Print whitespace to closing brace + sepDecl bracePos >> putText "}" + RecCon (L _ []) -> do + skipToBrace >> putText "{" + skipToBrace >> putText "}" + + where + -- Jump to the first brace of the first record of the first constructor. + skipToBrace = case (cEquals cfg, cFirstField cfg) of + (_, Indent y) | not (cBreakSingleConstructors cfg) -> newline >> spaces y + (SameLine, SameLine) -> space + (Indent x, Indent y) -> newline >> spaces (x + y + 2) + (SameLine, Indent y) -> newline >> spaces (consIndent + y) + (Indent _, SameLine) -> space + + -- Jump to the next declaration. + sepDecl bracePos = newline >> spaces case (cEquals cfg, cFirstField cfg) of + (_, Indent y) | not (cBreakSingleConstructors cfg) -> y + (SameLine, SameLine) -> bracePos + (Indent x, Indent y) -> x + y + 2 + (SameLine, Indent y) -> bracePos + y - 2 + (Indent x, SameLine) -> bracePos + x - 2 + +putNewtypeConstructor :: Config -> Located (ConDecl GhcPs) -> P () +putNewtypeConstructor cfg (L _ cons) = case cons of + ConDeclH98{..} -> + putRdrName con_name >> case con_args of + PrefixCon xs -> do + unless (null xs) space + sep space (fmap putOutputable xs) + RecCon (L _ [L _posFirst firstArg]) -> do + space + putText "{" + space + putConDeclField cfg firstArg + space + putText "}" + RecCon (L _ _args) -> + error . mconcat $ + [ "Language.Haskell.Stylish.Step.Data.putNewtypeConstructor: " + , "encountered newtype with several arguments" + ] + InfixCon {} -> + error . mconcat $ + [ "Language.Haskell.Stylish.Step.Data.putNewtypeConstructor: " + , "infix newtype constructor" + ] + XConDecl x -> + noExtCon x + ConDeclGADT{} -> + error . mconcat $ + [ "Language.Haskell.Stylish.Step.Data.putNewtypeConstructor: " + , "GADT encountered in newtype" + ] + +putContext :: Config -> HsContext GhcPs -> P () +putContext Config{..} = suffix (space >> putText "=>" >> space) . \case + [L _ (HsParTy _ tp)] | cCurriedContext -> + putType tp + [ctx] -> + putType ctx + ctxs | cCurriedContext -> + sep (space >> putText "=>" >> space) (fmap putType ctxs) + ctxs -> + parenthesize $ sep (comma >> space) (fmap putType ctxs) + +putConDeclField :: Config -> ConDeclField GhcPs -> P () +putConDeclField cfg = \case + ConDeclField{..} -> do + sep + (comma >> space) + (fmap putOutputable cd_fld_names) + space + putText "::" + space + putType' cfg cd_fld_type + XConDeclField{} -> + error . mconcat $ + [ "Language.Haskell.Stylish.Step.Data.putConDeclField: " + , "XConDeclField encountered" + ] + +-- | A variant of 'putType' that takes 'cCurriedContext' into account +putType' :: Config -> Located (HsType GhcPs) -> P () +putType' cfg = \case + L _ (HsForAllTy NoExtField vis bndrs tp) -> do + putText "forall" + space + sep space (fmap putOutputable bndrs) + putText + if vis == ForallVis then "->" + else "." + space + putType' cfg tp + L _ (HsQualTy NoExtField ctx tp) -> do + putContext cfg (unLocated ctx) + putType' cfg tp + other -> putType other + +newOrData :: DataDecl -> String +newOrData decl = if isNewtype decl then "newtype" else "data" + +isGADT :: DataDecl -> Bool +isGADT = any isGADTCons . dd_cons . dataDefn where - hasRecordFields = any - (\qual -> case qual of - (H.QualConDecl _ _ _ (H.RecDecl {})) -> True - _ -> False) - decls - - typeConstructor = "data " <> H.prettyPrint dhead - - -- In any case set @pipeIndent@ such that @|@ is aligned with @=@. - (firstLine, firstLineInit, pipeIndent) = - case cEquals of - SameLine -> (Nothing, typeConstructor <> " = ", length typeConstructor + 1) - Indent n -> (Just [[typeConstructor]], indent n "= ", n) - - newLines = fromMaybe [] firstLine ++ fmap constructors zipped <> [fmap (indent cDeriving . H.prettyPrint) derivings] - zipped = zip decls ([1..] ::[Int]) - - constructors (decl, 1) = processConstructor allComments firstLineInit cfg decl - constructors (decl, _) = processConstructor allComments (indent pipeIndent "| ") cfg decl -changeDecl _ _ _ = Nothing - -processConstructor :: [Comment] -> String -> Config -> H.QualConDecl LineBlock -> [String] -processConstructor allComments init Config{..} (H.QualConDecl _ _ _ (H.RecDecl _ dname (f:fs))) = do - fromMaybe [] firstLine <> n1 <> ns <> [indent fieldIndent "}"] + isGADTCons = \case + L _ (ConDeclGADT {}) -> True + _ -> False + +isNewtype :: DataDecl -> Bool +isNewtype = (== NewType) . dd_ND . dataDefn + +isInfix :: DataDecl -> Bool +isInfix = (== Infix) . dataFixity + +isEnum :: DataDecl -> Bool +isEnum = all isUnary . dd_cons . dataDefn where - n1 = processName firstLinePrefix (extractField f) - ns = fs >>= processName (indent fieldIndent ", ") . extractField - - -- Set @fieldIndent@ such that @,@ is aligned with @{@. - (firstLine, firstLinePrefix, fieldIndent) = - case cFirstField of - SameLine -> - ( Nothing - , init <> H.prettyPrint dname <> " { " - , length init + length (H.prettyPrint dname) + 1 - ) - Indent n -> - ( Just [init <> H.prettyPrint dname] - , indent (length init + n) "{ " - , length init + n - ) - - processName prefix (fnames, _type, lineComment, commentBelowLine) = - [prefix <> intercalate ", " (fmap H.prettyPrint fnames) <> " :: " <> H.prettyPrint _type <> addLineComment lineComment - ] ++ addCommentBelow commentBelowLine - - addLineComment (Just (Comment _ _ c)) = " --" <> c - addLineComment Nothing = "" - - -- Field comment indent is measured from the column with @{@, hence adding of @fieldIndent@ here. - addCommentBelow Nothing = [] - addCommentBelow (Just (Comment _ _ c)) = [indent (fieldIndent + cFieldComment) "--" <> c] - - extractField (H.FieldDecl lb names _type) = - (names, _type, findCommentOnLine lb allComments, findCommentBelowLine lb allComments) - -processConstructor _ init _ decl = [init <> trimLeft (H.prettyPrint decl)] + isUnary = \case + L _ (ConDeclH98 {..}) -> case con_args of + PrefixCon [] -> True + _ -> False + _ -> False + +hasConstructors :: DataDecl -> Bool +hasConstructors = not . null . dd_cons . dataDefn + +singleConstructor :: DataDecl -> Bool +singleConstructor = (== 1) . length . dd_cons . dataDefn + +hasDeriving :: DataDecl -> Bool +hasDeriving = not . null . unLocated . dd_derivs . dataDefn diff --git a/lib/Language/Haskell/Stylish/Step/Imports.hs b/lib/Language/Haskell/Stylish/Step/Imports.hs index 7cb78d4..b89d73f 100644 --- a/lib/Language/Haskell/Stylish/Step/Imports.hs +++ b/lib/Language/Haskell/Stylish/Step/Imports.hs @@ -1,61 +1,78 @@ -{-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE RecordWildCards #-} --------------------------------------------------------------------------------- +{-# LANGUAGE BlockArguments #-} +{-# LANGUAGE DoAndIfThenElse #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE RecordWildCards #-} module Language.Haskell.Stylish.Step.Imports - ( Options (..) - , defaultOptions - , ImportAlign (..) - , ListAlign (..) - , LongListAlign (..) - , EmptyListAlign (..) - , ListPadding (..) - , step - ) where + ( Options (..) + , defaultOptions + , ImportAlign (..) + , ListAlign (..) + , LongListAlign (..) + , EmptyListAlign (..) + , ListPadding (..) + , step + + , printImport + ) where + +-------------------------------------------------------------------------------- +import Control.Monad (forM_, when, void) +import Data.Function ((&), on) +import Data.Functor (($>)) +import Data.Foldable (toList) +import Data.Maybe (isJust) +import Data.List (sortBy) +import Data.List.NonEmpty (NonEmpty(..)) +import qualified Data.List.NonEmpty as NonEmpty +import qualified Data.Map as Map +import qualified Data.Set as Set -------------------------------------------------------------------------------- -import Control.Arrow ((&&&)) -import Control.Monad (void) -import qualified Data.Aeson as A -import qualified Data.Aeson.Types as A -import Data.Char (toLower) -import Data.List (intercalate, sortBy) -import qualified Data.Map as M -import Data.Maybe (isJust, maybeToList) -import Data.Ord (comparing) -import qualified Data.Set as S -import Data.Semigroup (Semigroup ((<>))) -import qualified Language.Haskell.Exts as H +import BasicTypes (StringLiteral (..), + SourceText (..)) +import qualified FastString as FS +import GHC.Hs.Extension (GhcPs) +import qualified GHC.Hs.Extension as GHC +import GHC.Hs.ImpExp +import Module (moduleNameString) +import RdrName (RdrName) +import SrcLoc (Located, GenLocated(..), unLoc) -------------------------------------------------------------------------------- import Language.Haskell.Stylish.Block -import Language.Haskell.Stylish.Editor +import Language.Haskell.Stylish.Module +import Language.Haskell.Stylish.Ordering +import Language.Haskell.Stylish.Printer import Language.Haskell.Stylish.Step +import Language.Haskell.Stylish.Editor +import Language.Haskell.Stylish.GHC import Language.Haskell.Stylish.Util + -------------------------------------------------------------------------------- data Options = Options - { importAlign :: ImportAlign - , listAlign :: ListAlign - , padModuleNames :: Bool - , longListAlign :: LongListAlign - , emptyListAlign :: EmptyListAlign - , listPadding :: ListPadding - , separateLists :: Bool - , spaceSurround :: Bool + { importAlign :: ImportAlign + , listAlign :: ListAlign + , padModuleNames :: Bool + , longListAlign :: LongListAlign + , emptyListAlign :: EmptyListAlign + , listPadding :: ListPadding + , separateLists :: Bool + , spaceSurround :: Bool } deriving (Eq, Show) defaultOptions :: Options defaultOptions = Options - { importAlign = Global - , listAlign = AfterAlias - , padModuleNames = True - , longListAlign = Inline - , emptyListAlign = Inherit - , listPadding = LPConstant 4 - , separateLists = True - , spaceSurround = False + { importAlign = Global + , listAlign = AfterAlias + , padModuleNames = True + , longListAlign = Inline + , emptyListAlign = Inherit + , listPadding = LPConstant 4 + , separateLists = True + , spaceSurround = False } data ListPadding @@ -75,6 +92,7 @@ data ListAlign | WithModuleName | WithAlias | AfterAlias + | Repeat deriving (Eq, Show) data EmptyListAlign @@ -83,375 +101,385 @@ data EmptyListAlign deriving (Eq, Show) data LongListAlign - = Inline - | InlineWithBreak - | InlineToMultiline - | Multiline + = Inline -- inline + | InlineWithBreak -- new_line + | InlineToMultiline -- new_line_multiline + | Multiline -- multiline deriving (Eq, Show) -------------------------------------------------------------------------------- - -modifyImportSpecs :: ([H.ImportSpec l] -> [H.ImportSpec l]) - -> H.ImportDecl l -> H.ImportDecl l -modifyImportSpecs f imp = imp {H.importSpecs = f' <$> H.importSpecs imp} - where - f' (H.ImportSpecList l h specs) = H.ImportSpecList l h (f specs) - - --------------------------------------------------------------------------------- -imports :: H.Module l -> [H.ImportDecl l] -imports (H.Module _ _ _ is _) = is -imports _ = [] - - --------------------------------------------------------------------------------- -importName :: H.ImportDecl l -> String -importName i = let (H.ModuleName _ n) = H.importModule i in n - -importPackage :: H.ImportDecl l -> Maybe String -importPackage i = H.importPkg i - - --------------------------------------------------------------------------------- --- | A "compound import name" is import's name and package (if present). For --- instance, if you have an import @Foo.Bar@ from package @foobar@, the full --- name will be @"foobar" Foo.Bar@. -compoundImportName :: H.ImportDecl l -> String -compoundImportName i = - case importPackage i of - Nothing -> importName i - Just pkg -> show pkg ++ " " ++ importName i - - --------------------------------------------------------------------------------- -longestImport :: [H.ImportDecl l] -> Int -longestImport = maximum . map (length . compoundImportName) - - --------------------------------------------------------------------------------- --- | Compare imports for ordering -compareImports :: H.ImportDecl l -> H.ImportDecl l -> Ordering -compareImports = - comparing (map toLower . importName &&& - fmap (map toLower) . importPackage &&& - H.importQualified) - - --------------------------------------------------------------------------------- --- | Remove (or merge) duplicated import specs. --- --- * When something is mentioned twice, it's removed: @A, A@ -> A --- * More general forms take priority: @A, A(..)@ -> @A(..)@ --- * Sometimes we have to combine imports: @A(x), A(y)@ -> @A(x, y)@ --- --- Import specs are always sorted by subsequent steps so we don't have to care --- about preserving order. -deduplicateImportSpecs :: Ord l => H.ImportDecl l -> H.ImportDecl l -deduplicateImportSpecs = - modifyImportSpecs $ - map recomposeImportSpec . - M.toList . M.fromListWith (<>) . - map decomposeImportSpec - --- | What we are importing (variable, class, etc) -data ImportEntity l - -- | A variable - = ImportVar l (H.Name l) - -- | Something that can be imported partially - | ImportClassOrData l (H.Name l) - -- | Something else ('H.IAbs') - | ImportOther l (H.Namespace l) (H.Name l) - deriving (Eq, Ord) - --- | What we are importing from an 'ImportClassOrData' -data ImportPortion l - = ImportSome [H.CName l] -- ^ @A(x, y, z)@ - | ImportAll -- ^ @A(..)@ - -instance Ord l => Semigroup (ImportPortion l) where - ImportSome a <> ImportSome b = ImportSome (setUnion a b) - _ <> _ = ImportAll - -instance Ord l => Monoid (ImportPortion l) where - mempty = ImportSome [] - mappend = (<>) - --- | O(n log n) union. -setUnion :: Ord a => [a] -> [a] -> [a] -setUnion a b = S.toList (S.fromList a `S.union` S.fromList b) - -decomposeImportSpec :: H.ImportSpec l -> (ImportEntity l, ImportPortion l) -decomposeImportSpec x = case x of - -- I checked and it looks like namespace's 'l' is always equal to x's 'l' - H.IAbs l space n -> case space of - H.NoNamespace _ -> (ImportClassOrData l n, ImportSome []) - H.TypeNamespace _ -> (ImportOther l space n, ImportSome []) - H.PatternNamespace _ -> (ImportOther l space n, ImportSome []) - H.IVar l n -> (ImportVar l n, ImportSome []) - H.IThingAll l n -> (ImportClassOrData l n, ImportAll) - H.IThingWith l n names -> (ImportClassOrData l n, ImportSome names) - -recomposeImportSpec :: (ImportEntity l, ImportPortion l) -> H.ImportSpec l -recomposeImportSpec (e, p) = case e of - ImportClassOrData l n -> case p of - ImportSome [] -> H.IAbs l (H.NoNamespace l) n - ImportSome names -> H.IThingWith l n names - ImportAll -> H.IThingAll l n - ImportVar l n -> H.IVar l n - ImportOther l space n -> H.IAbs l space n +step :: Maybe Int -> Options -> Step +step columns = makeStep "Imports (ghc-lib-parser)" . printImports columns -------------------------------------------------------------------------------- --- | The implementation is a bit hacky to get proper sorting for input specs: --- constructors first, followed by functions, and then operators. -compareImportSpecs :: H.ImportSpec l -> H.ImportSpec l -> Ordering -compareImportSpecs = comparing key +printImports :: Maybe Int -> Options -> Lines -> Module -> Lines +printImports maxCols align ls m = applyChanges changes ls where - key :: H.ImportSpec l -> (Int, Bool, String) - key (H.IVar _ x) = (1, isOperator x, nameToString x) - key (H.IAbs _ _ x) = (0, False, nameToString x) - key (H.IThingAll _ x) = (0, False, nameToString x) - key (H.IThingWith _ x _) = (0, False, nameToString x) - + groups = moduleImportGroups m + moduleStats = foldMap importStats . fmap unLoc $ concatMap toList groups + changes = do + group <- groups + pure $ formatGroup maxCols align m moduleStats group + +formatGroup + :: Maybe Int -> Options -> Module -> ImportStats + -> NonEmpty (Located Import) -> Change String +formatGroup maxCols options m moduleStats imports = + let newLines = formatImports maxCols options m moduleStats imports in + change (importBlock imports) (const newLines) + +importBlock :: NonEmpty (Located a) -> Block String +importBlock group = Block + (getStartLineUnsafe $ NonEmpty.head group) + (getEndLineUnsafe $ NonEmpty.last group) + +formatImports + :: Maybe Int -- ^ Max columns. + -> Options -- ^ Options. + -> Module -- ^ Module. + -> ImportStats -- ^ Module stats. + -> NonEmpty (Located Import) -> Lines +formatImports maxCols options m moduleStats rawGroup = + runPrinter_ (PrinterConfig maxCols) [] m do + let + + group + = NonEmpty.sortWith unLocated rawGroup + & mergeImports + + unLocatedGroup = fmap unLocated $ toList group + + align' = importAlign options + padModuleNames' = padModuleNames options + padNames = align' /= None && padModuleNames' + + stats = case align' of + Global -> moduleStats {isAnyQualified = True} + File -> moduleStats + Group -> foldMap importStats unLocatedGroup + None -> mempty + + forM_ group \imp -> printQualified options padNames stats imp >> newline -------------------------------------------------------------------------------- --- | Sort the input spec list inside an 'H.ImportDecl' -sortImportSpecs :: H.ImportDecl l -> H.ImportDecl l -sortImportSpecs = modifyImportSpecs (sortBy compareImportSpecs) +printQualified :: Options -> Bool -> ImportStats -> Located Import -> P () +printQualified Options{..} padNames stats (L _ decl) = do + let decl' = rawImport decl + + putText "import" >> space + + case (isSource decl, isAnySource stats) of + (True, _) -> putText "{-# SOURCE #-}" >> space + (_, True) -> putText " " >> space + _ -> pure () + + when (isSafe decl) (putText "safe" >> space) + + case (isQualified decl, isAnyQualified stats) of + (True, _) -> putText "qualified" >> space + (_, True) -> putText " " >> space + _ -> pure () + + moduleNamePosition <- length <$> getCurrentLine + forM_ (ideclPkgQual decl') $ \pkg -> putText (stringLiteral pkg) >> space + putText (moduleName decl) + + -- Only print spaces if something follows. + when padNames $ + when (isJust (ideclAs decl') || isHiding decl || + not (null $ ideclHiding decl')) $ + putText $ + replicate (isLongestImport stats - importModuleNameLength decl) ' ' + + beforeAliasPosition <- length <$> getCurrentLine + forM_ (ideclAs decl') \(L _ name) -> + space >> putText "as" >> space >> putText (moduleNameString name) + afterAliasPosition <- length <$> getCurrentLine + + when (isHiding decl) (space >> putText "hiding") + + let putOffset = putText $ replicate offset ' ' + offset = case listPadding of + LPConstant n -> n + LPModuleName -> moduleNamePosition + + case snd <$> ideclHiding decl' of + Nothing -> pure () + Just (L _ []) -> case emptyListAlign of + RightAfter -> modifyCurrentLine trimRight >> space >> putText "()" + Inherit -> case listAlign of + NewLine -> + modifyCurrentLine trimRight >> newline >> putOffset >> putText "()" + _ -> space >> putText "()" + Just (L _ imports) -> do + let printedImports = flagEnds $ -- [P ()] + fmap ((printImport separateLists) . unLocated) + (prepareImportList imports) + + -- Since we might need to output the import module name several times, we + -- need to save it to a variable: + wrapPrefix <- case listAlign of + AfterAlias -> pure $ replicate (afterAliasPosition + 1) ' ' + WithAlias -> pure $ replicate (beforeAliasPosition + 1) ' ' + Repeat -> fmap (++ " (") getCurrentLine + WithModuleName -> pure $ replicate (moduleNamePosition + offset) ' ' + NewLine -> pure $ replicate offset ' ' + + let -- Helper + doSpaceSurround = when spaceSurround space + + -- Try to put everything on one line. + printAsSingleLine = forM_ printedImports $ \(imp, start, end) -> do + when start $ putText "(" >> doSpaceSurround + imp + if end then doSpaceSurround >> putText ")" else comma >> space + + -- Try to put everything one by one, wrapping if that fails. + printAsInlineWrapping wprefix = forM_ printedImports $ + \(imp, start, end) -> + patchForRepeatHiding $ wrapping + (do + if start then putText "(" >> doSpaceSurround else space + imp + if end then doSpaceSurround >> putText ")" else comma) + (do + case listAlign of + -- In 'Repeat' mode, end lines with ')' rather than ','. + Repeat | not start -> modifyCurrentLine . withLast $ + \c -> if c == ',' then ')' else c + _ | start && spaceSurround -> + -- Only necessary if spaceSurround is enabled. + modifyCurrentLine trimRight + _ -> pure () + newline + void wprefix + case listAlign of + -- '(' already included in repeat + Repeat -> pure () + -- Print the much needed '(' + _ | start -> putText "(" >> doSpaceSurround + -- Don't bother aligning if we're not in inline mode. + _ | longListAlign /= Inline -> pure () + -- 'Inline + AfterAlias' is really where we want to be careful + -- with spacing. + AfterAlias -> space >> doSpaceSurround + WithModuleName -> pure () + WithAlias -> pure () + NewLine -> pure () + imp + if end then doSpaceSurround >> putText ")" else comma) + + -- Put everything on a separate line. 'spaceSurround' can be + -- ignored. + printAsMultiLine = forM_ printedImports $ \(imp, start, end) -> do + when start $ modifyCurrentLine trimRight -- We added some spaces. + newline + putOffset + if start then putText "( " else putText ", " + imp + when end $ newline >> putOffset >> putText ")" + + case longListAlign of + Multiline -> wrapping + (space >> printAsSingleLine) + printAsMultiLine + Inline | NewLine <- listAlign -> do + modifyCurrentLine trimRight + newline >> putOffset >> printAsInlineWrapping (putText wrapPrefix) + Inline -> space >> printAsInlineWrapping (putText wrapPrefix) + InlineWithBreak -> wrapping + (space >> printAsSingleLine) + (do + modifyCurrentLine trimRight + newline >> putOffset >> printAsInlineWrapping putOffset) + InlineToMultiline -> wrapping + (space >> printAsSingleLine) + (wrapping + (do + modifyCurrentLine trimRight + newline >> putOffset >> printAsSingleLine) + printAsMultiLine) + where + -- We cannot wrap/repeat 'hiding' imports since then we would get multiple + -- imports hiding different things. + patchForRepeatHiding = case listAlign of + Repeat | isHiding decl -> withColumns Nothing + _ -> id -------------------------------------------------------------------------------- --- | Order of imports in sublist is: --- Constructors, accessors/methods, operators. -compareImportSubSpecs :: H.CName l -> H.CName l -> Ordering -compareImportSubSpecs = comparing key - where - key :: H.CName l -> (Int, Bool, String) - key (H.ConName _ x) = (0, False, nameToString x) - key (H.VarName _ x) = (1, isOperator x, nameToString x) +printImport :: Bool -> IE GhcPs -> P () +printImport _ (IEVar _ name) = do + printIeWrappedName name +printImport _ (IEThingAbs _ name) = do + printIeWrappedName name +printImport separateLists (IEThingAll _ name) = do + printIeWrappedName name + when separateLists space + putText "(..)" +printImport _ (IEModuleContents _ (L _ m)) = do + putText "module" + space + putText (moduleNameString m) +printImport separateLists (IEThingWith _ name _wildcard imps _) = do + printIeWrappedName name + when separateLists space + parenthesize $ + sep (comma >> space) (printIeWrappedName <$> imps) +printImport _ (IEGroup _ _ _ ) = + error "Language.Haskell.Stylish.Printer.Imports.printImportExport: unhandled case 'IEGroup'" +printImport _ (IEDoc _ _) = + error "Language.Haskell.Stylish.Printer.Imports.printImportExport: unhandled case 'IEDoc'" +printImport _ (IEDocNamed _ _) = + error "Language.Haskell.Stylish.Printer.Imports.printImportExport: unhandled case 'IEDocNamed'" +printImport _ (XIE ext) = + GHC.noExtCon ext -------------------------------------------------------------------------------- --- | By default, haskell-src-exts pretty-prints --- --- > import Foo (Bar(..)) --- --- but we want --- --- > import Foo (Bar (..)) --- --- instead. -prettyImportSpec :: (Ord l) => Bool -> H.ImportSpec l -> String -prettyImportSpec separate = prettyImportSpec' +printIeWrappedName :: LIEWrappedName RdrName -> P () +printIeWrappedName lie = unLocated lie & \case + IEName n -> putRdrName n + IEPattern n -> putText "pattern" >> space >> putRdrName n + IEType n -> putText "type" >> space >> putRdrName n + +mergeImports :: NonEmpty (Located Import) -> NonEmpty (Located Import) +mergeImports (x :| []) = x :| [] +mergeImports (h :| (t : ts)) + | canMergeImport (unLocated h) (unLocated t) = mergeImports (mergeModuleImport h t :| ts) + | otherwise = h :| mergeImportsTail (t : ts) where - prettyImportSpec' (H.IThingAll _ n) = H.prettyPrint n ++ sep "(..)" - prettyImportSpec' (H.IThingWith _ n cns) = H.prettyPrint n - ++ sep "(" - ++ intercalate ", " - (map H.prettyPrint $ sortBy compareImportSubSpecs cns) - ++ ")" - prettyImportSpec' x = H.prettyPrint x + mergeImportsTail (x : y : ys) + | canMergeImport (unLocated x) (unLocated y) = mergeImportsTail ((mergeModuleImport x y) : ys) + | otherwise = x : mergeImportsTail (y : ys) + mergeImportsTail xs = xs - sep = if separate then (' ' :) else id +moduleName :: Import -> String +moduleName + = moduleNameString + . unLocated + . ideclName + . rawImport -------------------------------------------------------------------------------- -prettyImport :: (Ord l, Show l) => - Maybe Int -> Options -> Bool -> Bool -> Int -> H.ImportDecl l -> [String] -prettyImport columns Options{..} padQualified padName longest imp - | (void `fmap` H.importSpecs imp) == emptyImportSpec = emptyWrap - | otherwise = case longListAlign of - Inline -> inlineWrap - InlineWithBreak -> longListWrapper inlineWrap inlineWithBreakWrap - InlineToMultiline -> longListWrapper inlineWrap inlineToMultilineWrap - Multiline -> longListWrapper inlineWrap multilineWrap - where - emptyImportSpec = Just (H.ImportSpecList () False []) - -- "import" + space + qualifiedLength has space in it. - listPadding' = listPaddingValue (6 + 1 + qualifiedLength) listPadding - where - qualifiedLength = - if null qualified then 0 else 1 + sum (map length qualified) - - longListWrapper shortWrap longWrap - | listAlign == NewLine - || length shortWrap > 1 - || exceedsColumns (length (head shortWrap)) - = longWrap - | otherwise = shortWrap - - emptyWrap = case emptyListAlign of - Inherit -> inlineWrap - RightAfter -> [paddedNoSpecBase ++ " ()"] - - inlineWrap = inlineWrapper - $ mapSpecs - $ withInit (++ ",") - . withHead (("(" ++ maybeSpace) ++) - . withLast (++ (maybeSpace ++ ")")) - - inlineWrapper = case listAlign of - NewLine -> (paddedNoSpecBase :) . wrapRestMaybe columns listPadding' - WithModuleName -> wrapMaybe columns paddedBase (withModuleNameBaseLength + 4) - WithAlias -> wrapMaybe columns paddedBase (inlineBaseLength + 1) - -- Add 1 extra space to ensure same padding as in original code. - AfterAlias -> withTail ((' ' : maybeSpace) ++) - . wrapMaybe columns paddedBase (afterAliasBaseLength + 1) - - inlineWithBreakWrap = paddedNoSpecBase : wrapRestMaybe columns listPadding' - ( mapSpecs - $ withInit (++ ",") - . withHead (("(" ++ maybeSpace) ++) - . withLast (++ (maybeSpace ++ ")"))) - - inlineToMultilineWrap - | length inlineWithBreakWrap > 2 - || any (exceedsColumns . length) (tail inlineWithBreakWrap) - = multilineWrap - | otherwise = inlineWithBreakWrap - - -- 'wrapRest 0' ensures that every item of spec list is on new line. - multilineWrap = paddedNoSpecBase : wrapRest 0 listPadding' - ( mapSpecs - ( withHead ("( " ++) - . withTail (", " ++)) - ++ closer) - where - closer = if null importSpecs - then [] - else [")"] - - paddedBase = base $ padImport $ compoundImportName imp - - paddedNoSpecBase = base $ padImportNoSpec $ compoundImportName imp - - padImport = if hasExtras && padName - then padRight longest - else id - - padImportNoSpec = if (isJust (H.importAs imp) || hasHiding) && padName - then padRight longest - else id - - base' baseName importAs hasHiding' = unwords $ concat $ - [ ["import"] - , source - , safe - , qualified - , [baseName] - , importAs - , hasHiding' - ] - - base baseName = base' baseName - ["as " ++ as | H.ModuleName _ as <- maybeToList $ H.importAs imp] - ["hiding" | hasHiding] - - inlineBaseLength = length $ - base' (padImport $ compoundImportName imp) [] [] - - withModuleNameBaseLength = length $ base' "" [] [] - - afterAliasBaseLength = length $ base' (padImport $ compoundImportName imp) - ["as " ++ as | H.ModuleName _ as <- maybeToList $ H.importAs imp] [] - - (hasHiding, importSpecs) = case H.importSpecs imp of - Just (H.ImportSpecList _ h l) -> (h, Just l) - _ -> (False, Nothing) - - hasExtras = isJust (H.importAs imp) || isJust (H.importSpecs imp) - - qualified - | H.importQualified imp = ["qualified"] - | padQualified = - if H.importSrc imp - then [] - else if H.importSafe imp - then [" "] - else [" "] - | otherwise = [] - - safe - | H.importSafe imp = ["safe"] - | otherwise = [] - - source - | H.importSrc imp = ["{-# SOURCE #-}"] - | otherwise = [] - - mapSpecs f = case importSpecs of - Nothing -> [] -- Import everything - Just [] -> ["()"] -- Instance only imports - Just is -> f $ map (prettyImportSpec separateLists) is - - maybeSpace = case spaceSurround of - True -> " " - False -> "" - - exceedsColumns i = case columns of - Nothing -> False -- No number exceeds a maximum column count of - -- Nothing, because there is no limit to exceed. - Just c -> i > c - +data ImportStats = ImportStats + { isLongestImport :: !Int + , isAnySource :: !Bool + , isAnyQualified :: !Bool + , isAnySafe :: !Bool + } --------------------------------------------------------------------------------- -prettyImportGroup :: Maybe Int -> Options -> Bool -> Int - -> [H.ImportDecl LineBlock] - -> Lines -prettyImportGroup columns align fileAlign longest imps = - concatMap (prettyImport columns align padQual padName longest') $ - sortBy compareImports imps - where - align' = importAlign align - padModuleNames' = padModuleNames align +instance Semigroup ImportStats where + l <> r = ImportStats + { isLongestImport = isLongestImport l `max` isLongestImport r + , isAnySource = isAnySource l || isAnySource r + , isAnyQualified = isAnyQualified l || isAnyQualified r + , isAnySafe = isAnySafe l || isAnySafe r + } - longest' = case align' of - Group -> longestImport imps - _ -> longest +instance Monoid ImportStats where + mappend = (<>) + mempty = ImportStats 0 False False False - padName = align' /= None && padModuleNames' +importStats :: Import -> ImportStats +importStats i = + ImportStats (importModuleNameLength i) (isSource i) (isQualified i) (isSafe i) - padQual = case align' of - Global -> True - File -> fileAlign - Group -> any H.importQualified imps - None -> False +-- Computes length till module name, includes package name. +-- TODO: this should reuse code with the printer +importModuleNameLength :: Import -> Int +importModuleNameLength imp = + (case ideclPkgQual (rawImport imp) of + Nothing -> 0 + Just sl -> 1 + length (stringLiteral sl)) + + (length $ moduleName imp) -------------------------------------------------------------------------------- -step :: Maybe Int -> Options -> Step -step columns = makeStep "Imports" . step' columns +stringLiteral :: StringLiteral -> String +stringLiteral sl = case sl_st sl of + NoSourceText -> FS.unpackFS $ sl_fs sl + SourceText s -> s -------------------------------------------------------------------------------- -step' :: Maybe Int -> Options -> Lines -> Module -> Lines -step' columns align ls (module', _) = applyChanges - [ change block $ const $ - prettyImportGroup columns align fileAlign longest importGroup - | (block, importGroup) <- groups - ] - ls - where - imps = map (sortImportSpecs . deduplicateImportSpecs) $ - imports $ fmap linesFromSrcSpan module' - longest = longestImport imps - groups = groupAdjacent [(H.ann i, i) | i <- imps] - - fileAlign = case importAlign align of - File -> any H.importQualified imps - _ -> False +isQualified :: Import -> Bool +isQualified + = (/=) NotQualified + . ideclQualified + . rawImport + +isHiding :: Import -> Bool +isHiding + = maybe False fst + . ideclHiding + . rawImport + +isSource :: Import -> Bool +isSource + = ideclSource + . rawImport + +isSafe :: Import -> Bool +isSafe + = ideclSafe + . rawImport -------------------------------------------------------------------------------- -listPaddingValue :: Int -> ListPadding -> Int -listPaddingValue _ (LPConstant n) = n -listPaddingValue n LPModuleName = n +-- | Cleans up an import item list. +-- +-- * Sorts import items. +-- * Sort inner import lists, e.g. `import Control.Monad (Monad (return, join))` +-- * Removes duplicates from import lists. +prepareImportList :: [LIE GhcPs] -> [LIE GhcPs] +prepareImportList = + sortBy compareLIE . map (fmap prepareInner) . + concatMap (toList . snd) . Map.toAscList . mergeByName + where + mergeByName :: [LIE GhcPs] -> Map.Map RdrName (NonEmpty (LIE GhcPs)) + mergeByName imports0 = Map.fromListWith + -- Note that ideally every NonEmpty will just have a single entry and we + -- will be able to merge everything into that entry. Exotic imports can + -- mess this up, though. So they end up in the tail of the list. + (\(x :| xs) (y :| ys) -> case ieMerge (unLocated x) (unLocated y) of + Just z -> (x $> z) :| (xs ++ ys) -- Keep source from `x` + Nothing -> x :| (xs ++ y : ys)) + [(ieName $ unLocated imp, imp :| []) | imp <- imports0] + + prepareInner :: IE GhcPs -> IE GhcPs + prepareInner = \case + -- Simplify `A ()` to `A`. + IEThingWith x n NoIEWildcard [] [] -> IEThingAbs x n + IEThingWith x n w ns fs -> + IEThingWith x n w (sortBy (compareWrappedName `on` unLoc) ns) fs + ie -> ie + + -- Merge two import items, assuming they have the same name. + ieMerge :: IE GhcPs -> IE GhcPs -> Maybe (IE GhcPs) + ieMerge l@(IEVar _ _) _ = Just l + ieMerge _ r@(IEVar _ _) = Just r + ieMerge (IEThingAbs _ _) r = Just r + ieMerge l (IEThingAbs _ _) = Just l + ieMerge l@(IEThingAll _ _) _ = Just l + ieMerge _ r@(IEThingAll _ _) = Just r + ieMerge (IEThingWith x0 n0 w0 ns0 []) (IEThingWith _ _ w1 ns1 []) + | w0 /= w1 = Nothing + | otherwise = Just $ + -- TODO: sort the `ns0 ++ ns1`? + IEThingWith x0 n0 w0 (nubOn (unwrapName . unLoc) $ ns0 ++ ns1) [] + ieMerge _ _ = Nothing --------------------------------------------------------------------------------- -instance A.FromJSON ListPadding where - parseJSON (A.String "module_name") = return LPModuleName - parseJSON (A.Number n) | n' >= 1 = return $ LPConstant n' - where - n' = truncate n - parseJSON v = A.typeMismatch "'module_name' or >=1 number" v +-------------------------------------------------------------------------------- +nubOn :: Ord k => (a -> k) -> [a] -> [a] +nubOn f = go Set.empty + where + go _ [] = [] + go acc (x : xs) + | y `Set.member` acc = go acc xs + | otherwise = x : go (Set.insert y acc) xs + where + y = f x diff --git a/lib/Language/Haskell/Stylish/Step/LanguagePragmas.hs b/lib/Language/Haskell/Stylish/Step/LanguagePragmas.hs index c9d461f..ddfdeb0 100644 --- a/lib/Language/Haskell/Stylish/Step/LanguagePragmas.hs +++ b/lib/Language/Haskell/Stylish/Step/LanguagePragmas.hs @@ -1,4 +1,7 @@ -------------------------------------------------------------------------------- +{-# LANGUAGE BlockArguments #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} module Language.Haskell.Stylish.Step.LanguagePragmas ( Style (..) , step @@ -8,13 +11,23 @@ module Language.Haskell.Stylish.Step.LanguagePragmas -------------------------------------------------------------------------------- +import Data.List.NonEmpty (NonEmpty, fromList, toList) import qualified Data.Set as S -import qualified Language.Haskell.Exts as H +import Data.Text (Text) +import qualified Data.Text as T + + +-------------------------------------------------------------------------------- +import qualified GHC.Hs as Hs +import SrcLoc (RealSrcSpan, realSrcSpanStart, + srcLocLine, srcSpanEndLine, + srcSpanStartLine) -------------------------------------------------------------------------------- import Language.Haskell.Stylish.Block import Language.Haskell.Stylish.Editor +import Language.Haskell.Stylish.Module import Language.Haskell.Stylish.Step import Language.Haskell.Stylish.Util @@ -28,19 +41,6 @@ data Style -------------------------------------------------------------------------------- -pragmas :: H.Module l -> [(l, [String])] -pragmas (H.Module _ _ ps _ _) = - [(l, map nameToString names) | H.LanguagePragma l names <- ps] -pragmas _ = [] - - --------------------------------------------------------------------------------- --- | The start of the first block -firstLocation :: [(Block a, [String])] -> Int -firstLocation = minimum . map (blockStart . fst) - - --------------------------------------------------------------------------------- verticalPragmas :: String -> Int -> Bool -> [String] -> Lines verticalPragmas lg longest align pragmas' = [ "{-# " ++ lg ++ " " ++ pad pragma ++ " #-}" @@ -91,10 +91,10 @@ prettyPragmas lp cols _ align CompactLine = compactLinePragmas lp cols ali -------------------------------------------------------------------------------- -- | Filter redundant (and duplicate) pragmas out of the groups. As a side -- effect, we also sort the pragmas in their group... -filterRedundant :: (String -> Bool) - -> [(l, [String])] - -> [(l, [String])] -filterRedundant isRedundant' = snd . foldr filterRedundant' (S.empty, []) +filterRedundant :: (Text -> Bool) + -> [(l, NonEmpty Text)] + -> [(l, [Text])] +filterRedundant isRedundant' = snd . foldr filterRedundant' (S.empty, []) . fmap (fmap toList) where filterRedundant' (l, xs) (known, zs) | S.null xs' = (known', zs) @@ -111,38 +111,54 @@ step = ((((makeStep "LanguagePragmas" .) .) .) .) . step' -------------------------------------------------------------------------------- step' :: Maybe Int -> Style -> Bool -> Bool -> String -> Lines -> Module -> Lines -step' columns style align removeRedundant lngPrefix ls (module', _) - | null pragmas' = ls - | otherwise = applyChanges changes ls +step' columns style align removeRedundant lngPrefix ls m + | null languagePragmas = ls + | otherwise = applyChanges changes ls where isRedundant' - | removeRedundant = isRedundant module' + | removeRedundant = isRedundant m | otherwise = const False - pragmas' = pragmas $ fmap linesFromSrcSpan module' - longest = maximum $ map length $ snd =<< pragmas' - groups = [(b, concat pgs) | (b, pgs) <- groupAdjacent pragmas'] - changes = - [ change b (const $ prettyPragmas lngPrefix columns longest align style pg) - | (b, pg) <- filterRedundant isRedundant' groups - ] + languagePragmas = moduleLanguagePragmas m + + convertFstToBlock :: [(RealSrcSpan, a)] -> [(Block String, a)] + convertFstToBlock = fmap \(rspan, a) -> + (Block (srcSpanStartLine rspan) (srcSpanEndLine rspan), a) + + groupAdjacent' = + fmap turnSndBackToNel . groupAdjacent . fmap (fmap toList) + where + turnSndBackToNel (a, bss) = (a, fromList . concat $ bss) + + longest :: Int + longest = maximum $ map T.length $ toList . snd =<< languagePragmas + + groups :: [(Block String, NonEmpty Text)] + groups = [(b, pgs) | (b, pgs) <- groupAdjacent' (convertFstToBlock languagePragmas)] + + changes = + [ change b (const $ prettyPragmas lngPrefix columns longest align style (fmap T.unpack pg)) + | (b, pg) <- filterRedundant isRedundant' groups + ] -------------------------------------------------------------------------------- -- | Add a LANGUAGE pragma to a module if it is not present already. -addLanguagePragma :: String -> String -> H.Module H.SrcSpanInfo -> [Change String] +addLanguagePragma :: String -> String -> Module -> [Change String] addLanguagePragma lg prag modu | prag `elem` present = [] | otherwise = [insert line ["{-# " ++ lg ++ " " ++ prag ++ " #-}"]] where - pragmas' = pragmas (fmap linesFromSrcSpan modu) - present = concatMap snd pragmas' - line = if null pragmas' then 1 else firstLocation pragmas' + pragmas' = moduleLanguagePragmas modu + present = concatMap ((fmap T.unpack) . toList . snd) pragmas' + line = if null pragmas' then 1 else firstLocation pragmas' + firstLocation :: [(RealSrcSpan, NonEmpty Text)] -> Int + firstLocation = minimum . fmap (srcLocLine . realSrcSpanStart . fst) -------------------------------------------------------------------------------- -- | Check if a language pragma is redundant. We can't do this for all pragmas, -- but we do a best effort. -isRedundant :: H.Module H.SrcSpanInfo -> String -> Bool +isRedundant :: Module -> Text -> Bool isRedundant m "ViewPatterns" = isRedundantViewPatterns m isRedundant m "BangPatterns" = isRedundantBangPatterns m isRedundant _ _ = False @@ -150,13 +166,29 @@ isRedundant _ _ = False -------------------------------------------------------------------------------- -- | Check if the ViewPatterns language pragma is redundant. -isRedundantViewPatterns :: H.Module H.SrcSpanInfo -> Bool -isRedundantViewPatterns m = null - [() | H.PViewPat {} <- everything m :: [H.Pat H.SrcSpanInfo]] +isRedundantViewPatterns :: Module -> Bool +isRedundantViewPatterns = null . queryModule getViewPat + where + getViewPat :: Hs.Pat Hs.GhcPs -> [()] + getViewPat = \case + Hs.ViewPat{} -> [()] + _ -> [] -------------------------------------------------------------------------------- -- | Check if the BangPatterns language pragma is redundant. -isRedundantBangPatterns :: H.Module H.SrcSpanInfo -> Bool -isRedundantBangPatterns m = null - [() | H.PBangPat _ _ <- everything m :: [H.Pat H.SrcSpanInfo]] +isRedundantBangPatterns :: Module -> Bool +isRedundantBangPatterns modul = + (null $ queryModule getBangPat modul) && + (null $ queryModule getMatchStrict modul) + where + getBangPat :: Hs.Pat Hs.GhcPs -> [()] + getBangPat = \case + Hs.BangPat{} -> [()] + _ -> [] + + getMatchStrict :: Hs.Match Hs.GhcPs (Hs.LHsExpr Hs.GhcPs) -> [()] + getMatchStrict (Hs.XMatch m) = Hs.noExtCon m + getMatchStrict (Hs.Match _ ctx _ _) = case ctx of + Hs.FunRhs _ _ Hs.SrcStrict -> [()] + _ -> [] diff --git a/lib/Language/Haskell/Stylish/Step/ModuleHeader.hs b/lib/Language/Haskell/Stylish/Step/ModuleHeader.hs new file mode 100644 index 0000000..58752fe --- /dev/null +++ b/lib/Language/Haskell/Stylish/Step/ModuleHeader.hs @@ -0,0 +1,222 @@ +{-# LANGUAGE BlockArguments #-} +{-# LANGUAGE LambdaCase #-} +module Language.Haskell.Stylish.Step.ModuleHeader + ( Config (..) + , defaultConfig + , step + ) where + +-------------------------------------------------------------------------------- +import ApiAnnotation (AnnKeywordId (..), + AnnotationComment (..)) +import Control.Monad (forM_, join, when) +import Data.Bifunctor (second) +import Data.Foldable (find, toList) +import Data.Function ((&)) +import qualified Data.List as L +import Data.List.NonEmpty (NonEmpty (..)) +import qualified Data.List.NonEmpty as NonEmpty +import Data.Maybe (isJust, listToMaybe) +import qualified GHC.Hs.Doc as GHC +import GHC.Hs.Extension (GhcPs) +import qualified GHC.Hs.ImpExp as GHC +import qualified Module as GHC +import SrcLoc (GenLocated (..), + Located, RealLocated, + SrcSpan (..), + srcSpanEndLine, + srcSpanStartLine, unLoc) +import Util (notNull) + +-------------------------------------------------------------------------------- +import Language.Haskell.Stylish.Block +import Language.Haskell.Stylish.Editor +import Language.Haskell.Stylish.GHC +import Language.Haskell.Stylish.Module +import Language.Haskell.Stylish.Ordering +import Language.Haskell.Stylish.Printer +import Language.Haskell.Stylish.Step +import qualified Language.Haskell.Stylish.Step.Imports as Imports + + +data Config = Config + { indent :: Int + , sort :: Bool + , separateLists :: Bool + } + +defaultConfig :: Config +defaultConfig = Config + { indent = 4 + , sort = True + , separateLists = True + } + +step :: Config -> Step +step = makeStep "Module header" . printModuleHeader + +printModuleHeader :: Config -> Lines -> Module -> Lines +printModuleHeader conf ls m = + let + header = moduleHeader m + name = rawModuleName header + haddocks = rawModuleHaddocks header + exports = rawModuleExports header + annotations = rawModuleAnnotations m + + relevantComments :: [RealLocated AnnotationComment] + relevantComments + = moduleComments m + & rawComments + & dropAfterLocated exports + & dropBeforeLocated name + + -- TODO: pass max columns? + printedModuleHeader = runPrinter_ (PrinterConfig Nothing) relevantComments + m (printHeader conf name exports haddocks) + + getBlock loc = + Block <$> fmap getStartLineUnsafe loc <*> fmap getEndLineUnsafe loc + + adjustOffsetFrom :: Block a -> Block a -> Maybe (Block a) + adjustOffsetFrom (Block s0 _) b2@(Block s1 e1) + | s0 >= s1 && s0 >= e1 = Nothing + | s0 >= s1 = Just (Block (s0 + 1) e1) + | otherwise = Just b2 + + nameBlock = + getBlock name + + exportsBlock = + join $ adjustOffsetFrom <$> nameBlock <*> getBlock exports + + whereM :: Maybe SrcSpan + whereM + = annotations + & filter (\(((_, w), _)) -> w == AnnWhere) + & fmap (head . snd) -- get position of annot + & L.sort + & listToMaybe + + isModuleHeaderWhere :: Block a -> Bool + isModuleHeaderWhere w + = not + . overlapping + $ [w] <> toList nameBlock <> toList exportsBlock + + toLineBlock :: SrcSpan -> Block a + toLineBlock (RealSrcSpan s) = Block (srcSpanStartLine s) (srcSpanEndLine s) + toLineBlock s + = error + $ "'where' block was not a RealSrcSpan" <> show s + + whereBlock + = whereM + & fmap toLineBlock + & find isModuleHeaderWhere + + deletes = + fmap delete $ mergeAdjacent $ toList nameBlock <> toList exportsBlock <> toList whereBlock + + startLine = + maybe 1 blockStart nameBlock + + additions = [insert startLine printedModuleHeader] + + changes = deletes <> additions + in + applyChanges changes ls + +printHeader + :: Config + -> Maybe (Located GHC.ModuleName) + -> Maybe (Located [GHC.LIE GhcPs]) + -> Maybe GHC.LHsDocString + -> P () +printHeader conf mname mexps _ = do + forM_ mname \(L loc name) -> do + putText "module" + space + putText (showOutputable name) + attachEolComment loc + + maybe + (when (isJust mname) do newline >> spaces (indent conf) >> putText "where") + (printExportList conf) + mexps + +attachEolComment :: SrcSpan -> P () +attachEolComment = \case + UnhelpfulSpan _ -> pure () + RealSrcSpan rspan -> + removeLineComment (srcSpanStartLine rspan) >>= mapM_ \c -> space >> putComment c + +attachEolCommentEnd :: SrcSpan -> P () +attachEolCommentEnd = \case + UnhelpfulSpan _ -> pure () + RealSrcSpan rspan -> + removeLineComment (srcSpanEndLine rspan) >>= mapM_ \c -> space >> putComment c + +printExportList :: Config -> Located [GHC.LIE GhcPs] -> P () +printExportList conf (L srcLoc exports) = do + newline + doIndent >> putText "(" >> when (notNull exports) space + + exportsWithComments <- fmap (second doSort) <$> groupAttachedComments exports + + printExports exportsWithComments + + putText ")" >> space >> putText "where" >> attachEolCommentEnd srcLoc + where + -- 'doIndent' is @x@: + -- + -- > module Foo + -- > xxxx( foo + -- > xxxx, bar + -- > xxxx) where + -- + -- 'doHang' is @y@: + -- + -- > module Foo + -- > xxxx( -- Some comment + -- > xxxxyyfoo + -- > xxxx) where + doIndent = spaces (indent conf) + doHang = pad (indent conf + 2) + + doSort = if sort conf then NonEmpty.sortBy compareLIE else id + + printExports :: [([AnnotationComment], NonEmpty (GHC.LIE GhcPs))] -> P () + printExports (([], firstInGroup :| groupRest) : rest) = do + printExport firstInGroup + newline + doIndent + printExportsGroupTail groupRest + printExportsTail rest + printExports ((firstComment : comments, firstExport :| groupRest) : rest) = do + putComment firstComment >> newline >> doIndent + forM_ comments \c -> doHang >> putComment c >> newline >> doIndent + doHang + printExport firstExport + newline + doIndent + printExportsGroupTail groupRest + printExportsTail rest + printExports [] = + newline >> doIndent + + printExportsTail :: [([AnnotationComment], NonEmpty (GHC.LIE GhcPs))] -> P () + printExportsTail = mapM_ \(comments, exported) -> do + forM_ comments \c -> doHang >> putComment c >> newline >> doIndent + forM_ exported \export -> do + comma >> space >> printExport export + newline >> doIndent + + printExportsGroupTail :: [GHC.LIE GhcPs] -> P () + printExportsGroupTail (x : xs) = printExportsTail [([], x :| xs)] + printExportsGroupTail [] = pure () + + -- NOTE(jaspervdj): This code is almost the same as the import printing + -- in 'Imports' and should be merged. + printExport :: GHC.LIE GhcPs -> P () + printExport = Imports.printImport (separateLists conf) . unLoc diff --git a/lib/Language/Haskell/Stylish/Step/SimpleAlign.hs b/lib/Language/Haskell/Stylish/Step/SimpleAlign.hs index 5e61123..f8aea50 100644 --- a/lib/Language/Haskell/Stylish/Step/SimpleAlign.hs +++ b/lib/Language/Haskell/Stylish/Step/SimpleAlign.hs @@ -1,128 +1,202 @@ -------------------------------------------------------------------------------- +{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE TypeFamilies #-} module Language.Haskell.Stylish.Step.SimpleAlign ( Config (..) + , Align (..) , defaultConfig , step ) where -------------------------------------------------------------------------------- -import Data.Data (Data) -import Data.List (foldl') -import Data.Maybe (maybeToList) -import qualified Language.Haskell.Exts as H +import Data.Either (partitionEithers) +import Data.Foldable (toList) +import Data.List (foldl', foldl1', sortOn) +import Data.Maybe (fromMaybe) +import qualified GHC.Hs as Hs +import qualified SrcLoc as S -------------------------------------------------------------------------------- import Language.Haskell.Stylish.Align import Language.Haskell.Stylish.Editor +import Language.Haskell.Stylish.Module import Language.Haskell.Stylish.Step import Language.Haskell.Stylish.Util -------------------------------------------------------------------------------- data Config = Config - { cCases :: !Bool - , cTopLevelPatterns :: !Bool - , cRecords :: !Bool + { cCases :: Align + , cTopLevelPatterns :: Align + , cRecords :: Align + , cMultiWayIf :: Align } deriving (Show) +data Align + = Always + | Adjacent + | Never + deriving (Eq, Show) --------------------------------------------------------------------------------- defaultConfig :: Config defaultConfig = Config - { cCases = True - , cTopLevelPatterns = True - , cRecords = True + { cCases = Always + , cTopLevelPatterns = Always + , cRecords = Always + , cMultiWayIf = Always } +groupAlign :: Align -> [Alignable S.RealSrcSpan] -> [[Alignable S.RealSrcSpan]] +groupAlign a xs = case a of + Never -> [] + Adjacent -> byLine . sortOn (S.srcSpanStartLine . aLeft) $ xs + Always -> [xs] + where + byLine = map toList . groupByLine aLeft + -------------------------------------------------------------------------------- -cases :: Data l => H.Module l -> [[H.Alt l]] -cases modu = [alts | H.Case _ _ alts <- everything modu] +type Record = [S.Located (Hs.ConDeclField Hs.GhcPs)] -------------------------------------------------------------------------------- --- | For this to work well, we require a way to merge annotations. This merge --- operation should follow the semigroup laws. -altToAlignable :: (l -> l -> l) -> H.Alt l -> Maybe (Alignable l) -altToAlignable _ (H.Alt _ _ _ (Just _)) = Nothing -altToAlignable _ (H.Alt ann pat rhs@(H.UnGuardedRhs _ _) Nothing) = Just $ - Alignable - { aContainer = ann - , aLeft = H.ann pat - , aRight = H.ann rhs - , aRightLead = length "-> " - } -altToAlignable - merge - (H.Alt ann pat (H.GuardedRhss _ [H.GuardedRhs _ guards rhs]) Nothing) = - -- We currently only support the case where an alternative has a single - -- guarded RHS. If there are more, we would need to return multiple - -- `Alignable`s from this function, which would be a significant change. - Just $ Alignable - { aContainer = ann - , aLeft = foldl' merge (H.ann pat) (map H.ann guards) - , aRight = H.ann rhs - , aRightLead = length "-> " - } -altToAlignable _ _ = Nothing +records :: S.Located (Hs.HsModule Hs.GhcPs) -> [Record] +records modu = do + let decls = map S.unLoc (Hs.hsmodDecls (S.unLoc modu)) + tyClDecls = [ tyClDecl | Hs.TyClD _ tyClDecl <- decls ] + dataDecls = [ d | d@(Hs.DataDecl _ _ _ _ _) <- tyClDecls ] + dataDefns = map Hs.tcdDataDefn dataDecls + d@Hs.ConDeclH98 {} <- concatMap getConDecls dataDefns + case Hs.con_args d of + Hs.RecCon rec -> [S.unLoc rec] + _ -> [] + where + getConDecls :: Hs.HsDataDefn Hs.GhcPs -> [Hs.ConDecl Hs.GhcPs] + getConDecls d@Hs.HsDataDefn {} = map S.unLoc $ Hs.dd_cons d + getConDecls (Hs.XHsDataDefn x) = Hs.noExtCon x -------------------------------------------------------------------------------- -tlpats :: Data l => H.Module l -> [[H.Match l]] -tlpats modu = [matches | H.FunBind _ matches <- everything modu] +recordToAlignable :: Config -> Record -> [[Alignable S.RealSrcSpan]] +recordToAlignable conf = groupAlign (cRecords conf) . fromMaybe [] . traverse fieldDeclToAlignable -------------------------------------------------------------------------------- -matchToAlignable :: H.Match l -> Maybe (Alignable l) -matchToAlignable (H.InfixMatch _ _ _ _ _ _) = Nothing -matchToAlignable (H.Match _ _ [] _ _) = Nothing -matchToAlignable (H.Match _ _ _ _ (Just _)) = Nothing -matchToAlignable (H.Match ann name pats rhs Nothing) = Just $ Alignable - { aContainer = ann - , aLeft = last (H.ann name : map H.ann pats) - , aRight = H.ann rhs - , aRightLead = length "= " +fieldDeclToAlignable + :: S.Located (Hs.ConDeclField Hs.GhcPs) -> Maybe (Alignable S.RealSrcSpan) +fieldDeclToAlignable (S.L _ (Hs.XConDeclField x)) = Hs.noExtCon x +fieldDeclToAlignable (S.L matchLoc (Hs.ConDeclField _ names ty _)) = do + matchPos <- toRealSrcSpan matchLoc + leftPos <- toRealSrcSpan $ S.getLoc $ last names + tyPos <- toRealSrcSpan $ S.getLoc ty + Just $ Alignable + { aContainer = matchPos + , aLeft = leftPos + , aRight = tyPos + , aRightLead = length ":: " } -------------------------------------------------------------------------------- -records :: H.Module l -> [[H.FieldDecl l]] -records modu = - [ fields - | H.Module _ _ _ _ decls <- [modu] - , H.DataDecl _ _ _ _ cons _ <- decls - , H.QualConDecl _ _ _ (H.RecDecl _ _ fields) <- cons - ] +matchGroupToAlignable + :: Config + -> Hs.MatchGroup Hs.GhcPs (Hs.LHsExpr Hs.GhcPs) + -> [[Alignable S.RealSrcSpan]] +matchGroupToAlignable _conf (Hs.XMatchGroup x) = Hs.noExtCon x +matchGroupToAlignable conf (Hs.MG _ alts _) = cases' ++ patterns' + where + (cases, patterns) = partitionEithers . fromMaybe [] $ traverse matchToAlignable (S.unLoc alts) + cases' = groupAlign (cCases conf) cases + patterns' = groupAlign (cTopLevelPatterns conf) patterns -------------------------------------------------------------------------------- -fieldDeclToAlignable :: H.FieldDecl a -> Maybe (Alignable a) -fieldDeclToAlignable (H.FieldDecl ann names ty) = Just $ Alignable - { aContainer = ann - , aLeft = H.ann (last names) - , aRight = H.ann ty - , aRightLead = length ":: " +matchToAlignable + :: S.Located (Hs.Match Hs.GhcPs (Hs.LHsExpr Hs.GhcPs)) + -> Maybe (Either (Alignable S.RealSrcSpan) (Alignable S.RealSrcSpan)) +matchToAlignable (S.L matchLoc m@(Hs.Match _ Hs.CaseAlt pats@(_ : _) grhss)) = do + let patsLocs = map S.getLoc pats + pat = last patsLocs + guards = getGuards m + guardsLocs = map S.getLoc guards + left = foldl' S.combineSrcSpans pat guardsLocs + body <- rhsBody grhss + matchPos <- toRealSrcSpan matchLoc + leftPos <- toRealSrcSpan left + rightPos <- toRealSrcSpan $ S.getLoc body + Just . Left $ Alignable + { aContainer = matchPos + , aLeft = leftPos + , aRight = rightPos + , aRightLead = length "-> " + } +matchToAlignable (S.L matchLoc (Hs.Match _ (Hs.FunRhs name _ _) pats@(_ : _) grhss)) = do + body <- unguardedRhsBody grhss + let patsLocs = map S.getLoc pats + nameLoc = S.getLoc name + left = last (nameLoc : patsLocs) + bodyLoc = S.getLoc body + matchPos <- toRealSrcSpan matchLoc + leftPos <- toRealSrcSpan left + bodyPos <- toRealSrcSpan bodyLoc + Just . Right $ Alignable + { aContainer = matchPos + , aLeft = leftPos + , aRight = bodyPos + , aRightLead = length "= " } +matchToAlignable (S.L _ (Hs.XMatch x)) = Hs.noExtCon x +matchToAlignable (S.L _ (Hs.Match _ _ _ _)) = Nothing + + +-------------------------------------------------------------------------------- +multiWayIfToAlignable + :: Config + -> Hs.LHsExpr Hs.GhcPs + -> [[Alignable S.RealSrcSpan]] +multiWayIfToAlignable conf (S.L _ (Hs.HsMultiIf _ grhss)) = + groupAlign (cMultiWayIf conf) as + where + as = fromMaybe [] $ traverse grhsToAlignable grhss +multiWayIfToAlignable _conf _ = [] + + +-------------------------------------------------------------------------------- +grhsToAlignable + :: S.Located (Hs.GRHS Hs.GhcPs (Hs.LHsExpr Hs.GhcPs)) + -> Maybe (Alignable S.RealSrcSpan) +grhsToAlignable (S.L grhsloc (Hs.GRHS _ guards@(_ : _) body)) = do + let guardsLocs = map S.getLoc guards + bodyLoc = S.getLoc body + left = foldl1' S.combineSrcSpans guardsLocs + matchPos <- toRealSrcSpan grhsloc + leftPos <- toRealSrcSpan left + bodyPos <- toRealSrcSpan bodyLoc + Just $ Alignable + { aContainer = matchPos + , aLeft = leftPos + , aRight = bodyPos + , aRightLead = length "-> " + } +grhsToAlignable (S.L _ (Hs.XGRHS x)) = Hs.noExtCon x +grhsToAlignable (S.L _ _) = Nothing -------------------------------------------------------------------------------- step :: Maybe Int -> Config -> Step -step maxColumns config = makeStep "Cases" $ \ls (module', _) -> - let module'' = fmap H.srcInfoSpan module' +step maxColumns config@(Config {..}) = makeStep "Cases" $ \ls module' -> + let changes + :: (S.Located (Hs.HsModule Hs.GhcPs) -> [a]) + -> (a -> [[Alignable S.RealSrcSpan]]) + -> [Change String] changes search toAlign = - [ change_ - | case_ <- search module'' - , aligns <- maybeToList (mapM toAlign case_) - , change_ <- align maxColumns aligns - ] + (concatMap . concatMap) (align maxColumns) . map toAlign $ search (parsedModule module') + configured :: [Change String] configured = concat $ - [ changes cases (altToAlignable H.mergeSrcSpan) - | cCases config - ] ++ - [changes tlpats matchToAlignable | cTopLevelPatterns config] ++ - [changes records fieldDeclToAlignable | cRecords config] - - in applyChanges configured ls + [changes records (recordToAlignable config)] ++ + [changes everything (matchGroupToAlignable config)] ++ + [changes everything (multiWayIfToAlignable config)] in + applyChanges configured ls diff --git a/lib/Language/Haskell/Stylish/Step/Squash.hs b/lib/Language/Haskell/Stylish/Step/Squash.hs index 0eb4895..23d1e9f 100644 --- a/lib/Language/Haskell/Stylish/Step/Squash.hs +++ b/lib/Language/Haskell/Stylish/Step/Squash.hs @@ -1,4 +1,7 @@ -------------------------------------------------------------------------------- +{-# LANGUAGE PartialTypeSignatures #-} +{-# LANGUAGE PatternGuards #-} +{-# LANGUAGE TypeFamilies #-} module Language.Haskell.Stylish.Step.Squash ( step ) where @@ -6,7 +9,8 @@ module Language.Haskell.Stylish.Step.Squash -------------------------------------------------------------------------------- import Data.Maybe (mapMaybe) -import qualified Language.Haskell.Exts as H +import qualified GHC.Hs as Hs +import qualified SrcLoc as S -------------------------------------------------------------------------------- @@ -17,46 +21,43 @@ import Language.Haskell.Stylish.Util -------------------------------------------------------------------------------- squash - :: (H.Annotated l, H.Annotated r) - => l H.SrcSpan -> r H.SrcSpan -> Maybe (Change String) -squash left right - | H.srcSpanEndLine lAnn == H.srcSpanStartLine rAnn = Just $ - changeLine (H.srcSpanEndLine lAnn) $ \str -> - let (pre, post) = splitAt (H.srcSpanEndColumn lAnn) str - in [trimRight pre ++ " " ++ trimLeft post] - | otherwise = Nothing - where - lAnn = H.ann left - rAnn = H.ann right - - --------------------------------------------------------------------------------- -squashFieldDecl :: H.FieldDecl H.SrcSpan -> Maybe (Change String) -squashFieldDecl (H.FieldDecl _ names type') + :: (S.HasSrcSpan l, S.HasSrcSpan r) + => l -> r -> Maybe (Change String) +squash left right = do + lAnn <- toRealSrcSpan $ S.getLoc left + rAnn <- toRealSrcSpan $ S.getLoc right + if S.srcSpanEndLine lAnn == S.srcSpanStartLine rAnn || + S.srcSpanEndLine lAnn + 1 == S.srcSpanStartLine rAnn + then Just $ + changeLine (S.srcSpanEndLine lAnn) $ \str -> + let (pre, post) = splitAt (S.srcSpanEndCol lAnn) str + in [trimRight pre ++ " " ++ trimLeft post] + else Nothing + + +-------------------------------------------------------------------------------- +squashFieldDecl :: Hs.ConDeclField Hs.GhcPs -> Maybe (Change String) +squashFieldDecl (Hs.ConDeclField _ names type' _) | null names = Nothing | otherwise = squash (last names) type' +squashFieldDecl (Hs.XConDeclField x) = Hs.noExtCon x -------------------------------------------------------------------------------- -squashMatch :: H.Match H.SrcSpan -> Maybe (Change String) -squashMatch (H.InfixMatch _ _ _ _ _ _) = Nothing -squashMatch (H.Match _ name pats rhs _) - | null pats = squash name rhs - | otherwise = squash (last pats) rhs - - --------------------------------------------------------------------------------- -squashAlt :: H.Alt H.SrcSpan -> Maybe (Change String) -squashAlt (H.Alt _ pat rhs _) = squash pat rhs +squashMatch :: Hs.Match Hs.GhcPs (Hs.LHsExpr Hs.GhcPs) -> Maybe (Change String) +squashMatch (Hs.Match _ (Hs.FunRhs name _ _) [] grhss) = do + body <- unguardedRhsBody grhss + squash name body +squashMatch (Hs.Match _ _ pats grhss) = do + body <- unguardedRhsBody grhss + squash (last pats) body +squashMatch (Hs.XMatch x) = Hs.noExtCon x -------------------------------------------------------------------------------- step :: Step -step = makeStep "Squash" $ \ls (module', _) -> - let module'' = fmap H.srcInfoSpan module' - changes = concat - [ mapMaybe squashAlt (everything module'') - , mapMaybe squashMatch (everything module'') - , mapMaybe squashFieldDecl (everything module'') - ] - in applyChanges changes ls +step = makeStep "Squash" $ \ls (module') -> + let changes = + mapMaybe squashFieldDecl (everything module') ++ + mapMaybe squashMatch (everything module') in + applyChanges changes ls diff --git a/lib/Language/Haskell/Stylish/Step/UnicodeSyntax.hs b/lib/Language/Haskell/Stylish/Step/UnicodeSyntax.hs index 266e8e5..ff01dee 100644 --- a/lib/Language/Haskell/Stylish/Step/UnicodeSyntax.hs +++ b/lib/Language/Haskell/Stylish/Step/UnicodeSyntax.hs @@ -10,17 +10,17 @@ import Data.List (isPrefixOf, import Data.Map (Map) import qualified Data.Map as M import Data.Maybe (maybeToList) -import qualified Language.Haskell.Exts as H - - +import GHC.Hs.Binds +import GHC.Hs.Extension (GhcPs) +import GHC.Hs.Types -------------------------------------------------------------------------------- import Language.Haskell.Stylish.Block import Language.Haskell.Stylish.Editor +import Language.Haskell.Stylish.Module import Language.Haskell.Stylish.Step import Language.Haskell.Stylish.Step.LanguagePragmas (addLanguagePragma) import Language.Haskell.Stylish.Util - -------------------------------------------------------------------------------- unicodeReplacements :: Map String String unicodeReplacements = M.fromList @@ -39,7 +39,7 @@ replaceAll :: [(Int, [(Int, String)])] -> [Change String] replaceAll = map changeLine' where changeLine' (r, ns) = changeLine r $ \str -> return $ - applyChanges + applyChanges [ change (Block c ec) (const repl) | (c, needle) <- sort ns , let ec = c + length needle - 1 @@ -52,38 +52,17 @@ groupPerLine :: [((Int, Int), a)] -> [(Int, [(Int, a)])] groupPerLine = M.toList . M.fromListWith (++) . map (\((r, c), x) -> (r, [(c, x)])) - --------------------------------------------------------------------------------- -typeSigs :: H.Module H.SrcSpanInfo -> Lines -> [((Int, Int), String)] -typeSigs module' ls = - [ (pos, "::") - | H.TypeSig loc _ _ <- everything module' :: [H.Decl H.SrcSpanInfo] - , (start, end) <- infoPoints loc - , pos <- maybeToList $ between start end "::" ls - ] - - --------------------------------------------------------------------------------- -contexts :: H.Module H.SrcSpanInfo -> Lines -> [((Int, Int), String)] -contexts module' ls = - [ (pos, "=>") - | context <- everything module' :: [H.Context H.SrcSpanInfo] - , (start, end) <- infoPoints $ H.ann context - , pos <- maybeToList $ between start end "=>" ls +-- | Find symbol positions in the module. Currently only searches in type +-- signatures. +findSymbol :: Module -> Lines -> String -> [((Int, Int), String)] +findSymbol module' ls sym = + [ (pos, sym) + | TypeSig _ funLoc typeLoc <- everything (rawModuleDecls $ moduleDecls module') :: [Sig GhcPs] + , (funStart, _) <- infoPoints funLoc + , (_, typeEnd) <- infoPoints [hsSigWcType typeLoc] + , pos <- maybeToList $ between funStart typeEnd sym ls ] - --------------------------------------------------------------------------------- -typeFuns :: H.Module H.SrcSpanInfo -> Lines -> [((Int, Int), String)] -typeFuns module' ls = - [ (pos, "->") - | H.TyFun _ t1 t2 <- everything module' - , let start = H.srcSpanEnd $ H.srcInfoSpan $ H.ann t1 - , let end = H.srcSpanStart $ H.srcInfoSpan $ H.ann t2 - , pos <- maybeToList $ between start end "->" ls - ] - - -------------------------------------------------------------------------------- -- | Search for a needle in a haystack of lines. Only part the inside (startRow, -- startCol), (endRow, endCol) is searched. The return value is the position of @@ -110,11 +89,9 @@ step = (makeStep "UnicodeSyntax" .) . step' -------------------------------------------------------------------------------- step' :: Bool -> String -> Lines -> Module -> Lines -step' alp lg ls (module', _) = applyChanges changes ls +step' alp lg ls module' = applyChanges changes ls where changes = (if alp then addLanguagePragma lg "UnicodeSyntax" module' else []) ++ replaceAll perLine - perLine = sort $ groupPerLine $ - typeSigs module' ls ++ - contexts module' ls ++ - typeFuns module' ls + toReplace = [ "::", "=>", "->" ] + perLine = sort $ groupPerLine $ concatMap (findSymbol module' ls) toReplace diff --git a/lib/Language/Haskell/Stylish/Util.hs b/lib/Language/Haskell/Stylish/Util.hs index 9883f4b..1d35a03 100644 --- a/lib/Language/Haskell/Stylish/Util.hs +++ b/lib/Language/Haskell/Stylish/Util.hs @@ -1,8 +1,8 @@ -------------------------------------------------------------------------------- +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PatternGuards #-} module Language.Haskell.Stylish.Util - ( nameToString - , isOperator - , indent + ( indent , padRight , everything , infoPoints @@ -13,22 +13,35 @@ module Language.Haskell.Stylish.Util , wrapMaybe , wrapRestMaybe + -- * Extra list functions , withHead , withInit , withTail , withLast + , flagEnds + + , toRealSrcSpan + + , traceOutputable + , traceOutputableM + + , unguardedRhsBody + , rhsBody + + , getGuards ) where -------------------------------------------------------------------------------- -import Control.Arrow ((&&&), (>>>)) -import Data.Char (isAlpha, isSpace) +import Data.Char (isSpace) import Data.Data (Data) import qualified Data.Generics as G -import Data.Maybe (fromMaybe, listToMaybe, - maybeToList) +import Data.Maybe (maybeToList) import Data.Typeable (cast) -import qualified Language.Haskell.Exts as H +import Debug.Trace (trace) +import qualified GHC.Hs as Hs +import qualified Outputable +import qualified SrcLoc as S -------------------------------------------------------------------------------- @@ -36,18 +49,6 @@ import Language.Haskell.Stylish.Step -------------------------------------------------------------------------------- -nameToString :: H.Name l -> String -nameToString (H.Ident _ str) = str -nameToString (H.Symbol _ str) = str - - --------------------------------------------------------------------------------- -isOperator :: H.Name l -> Bool -isOperator = fromMaybe False - . (fmap (not . isAlpha) . listToMaybe) - . nameToString - --------------------------------------------------------------------------------- indent :: Int -> String -> String indent len = (indentPrefix len ++) @@ -68,8 +69,16 @@ everything = G.everything (++) (maybeToList . cast) -------------------------------------------------------------------------------- -infoPoints :: H.SrcSpanInfo -> [((Int, Int), (Int, Int))] -infoPoints = H.srcInfoPoints >>> map (H.srcSpanStart &&& H.srcSpanEnd) +infoPoints :: [S.Located pass] -> [((Int, Int), (Int, Int))] +infoPoints = fmap (helper . S.getLoc) + where + helper :: S.SrcSpan -> ((Int, Int), (Int, Int)) + helper (S.RealSrcSpan s) = do + let + start = S.realSrcSpanStart s + end = S.realSrcSpanEnd s + ((S.srcLocLine start, S.srcLocCol start), (S.srcLocLine end, S.srcLocCol end)) + helper _ = ((-1,-1), (-1,-1)) -------------------------------------------------------------------------------- @@ -117,7 +126,7 @@ noWrap :: String -- ^ Leading string -> Lines -- ^ Resulting lines noWrap leading _ind = noWrap' leading where - noWrap' ss [] = [ss] + noWrap' ss [] = [ss] noWrap' ss (str:strs) = noWrap' (ss ++ " " ++ str) strs @@ -181,7 +190,78 @@ withInit _ [] = [] withInit _ [x] = [x] withInit f (x : xs) = f x : withInit f xs + -------------------------------------------------------------------------------- withTail :: (a -> a) -> [a] -> [a] withTail _ [] = [] withTail f (x : xs) = x : map f xs + + + +-------------------------------------------------------------------------------- +-- | Utility for traversing through a list and knowing when you're at the +-- first and last element. +flagEnds :: [a] -> [(a, Bool, Bool)] +flagEnds = \case + [] -> [] + [x] -> [(x, True, True)] + x : y : zs -> (x, True, False) : go (y : zs) + where + go (x : y : zs) = (x, False, False) : go (y : zs) + go [x] = [(x, False, True)] + go [] = [] + + +-------------------------------------------------------------------------------- +traceOutputable :: Outputable.Outputable a => String -> a -> b -> b +traceOutputable title x = + trace (title ++ ": " ++ (Outputable.showSDocUnsafe $ Outputable.ppr x)) + + +-------------------------------------------------------------------------------- +traceOutputableM :: (Outputable.Outputable a, Monad m) => String -> a -> m () +traceOutputableM title x = traceOutputable title x $ pure () + + +-------------------------------------------------------------------------------- +-- take the (Maybe) RealSrcSpan out of the SrcSpan +toRealSrcSpan :: S.SrcSpan -> Maybe S.RealSrcSpan +toRealSrcSpan (S.RealSrcSpan s) = Just s +toRealSrcSpan _ = Nothing + + +-------------------------------------------------------------------------------- +-- Utility: grab the body out of guarded RHSs if it's a single unguarded one. +unguardedRhsBody :: Hs.GRHSs Hs.GhcPs a -> Maybe a +unguardedRhsBody (Hs.GRHSs _ [grhs] _) + | Hs.GRHS _ [] body <- S.unLoc grhs = Just body +unguardedRhsBody _ = Nothing + + +-- Utility: grab the body out of guarded RHSs +rhsBody :: Hs.GRHSs Hs.GhcPs a -> Maybe a +rhsBody (Hs.GRHSs _ [grhs] _) + | Hs.GRHS _ _ body <- S.unLoc grhs = Just body +rhsBody _ = Nothing + + +-------------------------------------------------------------------------------- +-- get guards in a guarded rhs of a Match +getGuards :: Hs.Match Hs.GhcPs (Hs.LHsExpr Hs.GhcPs) -> [Hs.GuardLStmt Hs.GhcPs] +getGuards (Hs.Match _ _ _ grhss) = + let + lgrhs = getLocGRHS grhss -- [] + grhs = map S.unLoc lgrhs + in + concatMap getGuardLStmts grhs +getGuards (Hs.XMatch x) = Hs.noExtCon x + + +getLocGRHS :: Hs.GRHSs Hs.GhcPs (Hs.LHsExpr Hs.GhcPs) -> [Hs.LGRHS Hs.GhcPs (Hs.LHsExpr Hs.GhcPs)] +getLocGRHS (Hs.GRHSs _ guardeds _) = guardeds +getLocGRHS (Hs.XGRHSs x) = Hs.noExtCon x + + +getGuardLStmts :: Hs.GRHS Hs.GhcPs (Hs.LHsExpr Hs.GhcPs) -> [Hs.GuardLStmt Hs.GhcPs] +getGuardLStmts (Hs.GRHS _ guards _) = guards +getGuardLStmts (Hs.XGRHS x) = Hs.noExtCon x |