{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE LambdaCase        #-}
{-# LANGUAGE OverloadedStrings #-}

-- |
-- Wrapper functions of "Network.HTTP.Simple" and "Network.HTTP.Client" that
-- add the @User-Agent@ HTTP request header to each request.


module Network.HTTP.StackClient
  ( httpJSON
  , httpLbs
  , httpNoBody
  , httpSink
  , withResponse
  , setRequestCheckStatus
  , setRequestMethod
  , setRequestHeader
  , addRequestHeader
  , setRequestBody
  , getResponseHeaders
  , getResponseBody
  , getResponseStatusCode
  , parseRequest
  , getUri
  , path
  , checkResponse
  , parseUrlThrow
  , requestHeaders
  , getGlobalManager
  , applyDigestAuth
  , displayDigestAuthException
  , Request
  , RequestBody(RequestBodyBS, RequestBodyLBS)
  , Response
  , HttpException
  , hAccept
  , hContentLength
  , hContentMD5
  , methodPut
  , formDataBody
  , partFileRequestBody
  , partBS
  , partLBS
  , setGitHubHeaders
  , download
  , redownload
  , verifiedDownload
  , verifiedDownloadWithProgress
  , CheckHexDigest (..)
  , DownloadRequest
  , drRetryPolicyDefault
  , VerifiedDownloadException (..)
  , HashCheck (..)
  , mkDownloadRequest
  , setHashChecks
  , setLengthCheck
  , setRetryPolicy
  , setForceDownload
  ) where

import           Control.Monad.State (get, put, modify)
import           Data.Aeson (FromJSON)
import qualified Data.ByteString as Strict
import           Data.Conduit (ConduitM, ConduitT, awaitForever, (.|), yield, await)
import           Data.Conduit.Lift (evalStateC)
import qualified Data.Conduit.List as CL
import           Data.Monoid (Sum (..))
import qualified Data.Text as T
import           Data.Time.Clock (NominalDiffTime, diffUTCTime, getCurrentTime)
import           Network.HTTP.Client (Request, RequestBody(..), Response, parseRequest, getUri, path, checkResponse, parseUrlThrow)
import           Network.HTTP.Simple (setRequestCheckStatus, setRequestMethod, setRequestBody, setRequestHeader, addRequestHeader, HttpException(..), getResponseBody, getResponseStatusCode, getResponseHeaders)
import           Network.HTTP.Types (hAccept, hContentLength, hContentMD5, methodPut)
import           Network.HTTP.Conduit (requestHeaders)
import           Network.HTTP.Client.TLS (getGlobalManager, applyDigestAuth, displayDigestAuthException)
import           Network.HTTP.Download hiding (download, redownload, verifiedDownload)
import qualified Network.HTTP.Download as Download
import qualified Network.HTTP.Simple
import           Network.HTTP.Client.MultipartFormData (formDataBody, partFileRequestBody, partBS, partLBS)
import           Path
import           Prelude (until, (!!))
import           RIO
import           RIO.PrettyPrint
import           Text.Printf (printf)


-- | Set the @User-Agent@ request header to @The Haskell Stack@.
setUserAgent :: Request -> Request
setUserAgent = setRequestHeader "User-Agent" ["The Haskell Stack"]


-- | 'Network.HTTP.Simple.httpJSON', with the Stack @User-Agent@ header set
-- on the request first.
httpJSON :: (MonadIO m, FromJSON a) => Request -> m (Response a)
httpJSON = Network.HTTP.Simple.httpJSON . setUserAgent


-- | 'Network.HTTP.Simple.httpLbs', with the Stack @User-Agent@ header set
-- on the request first.
httpLbs :: MonadIO m => Request -> m (Response LByteString)
httpLbs = Network.HTTP.Simple.httpLbs . setUserAgent


-- | 'Network.HTTP.Simple.httpNoBody', with the Stack @User-Agent@ header set
-- on the request first.
httpNoBody :: MonadIO m => Request -> m (Response ())
httpNoBody = Network.HTTP.Simple.httpNoBody . setUserAgent


-- | 'Network.HTTP.Simple.httpSink', with the Stack @User-Agent@ header set
-- on the request first.
httpSink
  :: MonadUnliftIO m
  => Request
  -> (Response () -> ConduitM Strict.ByteString Void m a)
  -> m a
httpSink = Network.HTTP.Simple.httpSink . setUserAgent


-- | 'Network.HTTP.Simple.withResponse', with the Stack @User-Agent@ header
-- set on the request first.
withResponse
  :: (MonadUnliftIO m, MonadIO n)
  => Request -> (Response (ConduitM i Strict.ByteString n ()) -> m a) -> m a
withResponse = Network.HTTP.Simple.withResponse . setUserAgent

-- | Set the @Accept@ request header to @application/vnd.github.v3+json@,
-- the media type used for requests to GitHub's v3 API.
--
-- Note: this does not touch @User-Agent@; the download and @http*@ wrappers
-- in this module add that header themselves.
setGitHubHeaders :: Request -> Request
setGitHubHeaders = setRequestHeader "Accept" ["application/vnd.github.v3+json"]

-- | Download the given URL to the given location. If the file already exists,
-- no download is performed. Otherwise, creates the parent directory, downloads
-- to a temporary file, and on file download completion moves to the
-- appropriate destination.
--
-- The Stack @User-Agent@ header is set on the request before delegating to
-- 'Download.download'.
--
-- Throws an exception if things go wrong.
download :: HasTerm env
         => Request
         -> Path Abs File -- ^ destination
         -> RIO env Bool
         -- ^ Was a download performed ('True') or did the file already
         -- exist ('False')?
download req dest = Download.download (setUserAgent req) dest

-- | Same as 'download', but will download a file a second time if it is
-- already present.
--
-- The Stack @User-Agent@ header is set on the request before delegating to
-- 'Download.redownload'.
--
-- Returns 'True' if the file was downloaded, 'False' otherwise.
redownload :: HasTerm env
           => Request
           -> Path Abs File -- ^ destination
           -> RIO env Bool
redownload req dest = Download.redownload (setUserAgent req) dest

-- | Verified download: sets the Stack @User-Agent@ header on the
-- 'DownloadRequest' (via 'modifyRequest') and delegates to
-- 'Download.verifiedDownload'. This function adds nothing else.
--
-- The underlying 'Download.verifiedDownload' is described as a copied and
-- extended version of @Network.HTTP.Download.download@, with the following
-- additional features:
--
-- * Verifies that response content-length header (if present)
--   matches expected length
-- * Limits the download to (close to) the expected # of bytes
-- * Verifies that the expected # bytes were downloaded (not too few)
-- * Verifies md5 if response includes content-md5 header
-- * Verifies the expected hashes
--
-- Throws 'VerifiedDownloadException'.
-- Throws 'IOException's related to file system operations.
-- Throws 'HttpException'.
verifiedDownload
         :: HasTerm env
         => DownloadRequest
         -> Path Abs File -- ^ destination
         -> (Maybe Integer -> ConduitM ByteString Void (RIO env) ()) -- ^ custom hook to observe progress
         -> RIO env Bool -- ^ Whether a download was performed
verifiedDownload dr destpath progressSink =
    Download.verifiedDownload dr' destpath progressSink
  where
    dr' = modifyRequest setUserAgent dr

-- | 'verifiedDownload' with 'chattyDownloadProgress' as the progress hook.
-- Progress is reported through sticky log messages prefixed with the given
-- label.
verifiedDownloadWithProgress
  :: HasTerm env
  => DownloadRequest
  -> Path Abs File -- ^ destination
  -> Text -- ^ label prefixed to progress messages
  -> Maybe Int -- ^ expected total size in bytes, if known
  -> RIO env Bool -- ^ Whether a download was performed
verifiedDownloadWithProgress req destpath lbl msize =
    verifiedDownload req destpath (chattyDownloadProgress lbl msize)

-- | Progress sink for downloads. It consumes the downloaded 'ByteString'
-- chunks and reports the running byte count with 'logSticky'.
--
-- Chunk sizes are batched with @'chunksOverTime' 1@. Progress messages are
-- therefore spaced more than one second apart, plus one final message when
-- the input ends.
--
-- The third argument is ignored. It only exists so that a partially applied
-- @chattyDownloadProgress label size@ fits the progress-hook argument of
-- 'verifiedDownload'.
chattyDownloadProgress
  :: ( HasLogFunc env
     , MonadIO m
     , MonadReader env m
     )
  => Text -- ^ label prefixed to every progress message
  -> Maybe Int -- ^ expected total size in bytes; 'Nothing' or @Just 0@ means unknown
  -> f
  -> ConduitT ByteString c m ()
chattyDownloadProgress label mtotalSize _ = do
    logSticky $ RIO.display label <> ": download has begun"
    CL.map (Sum . Strict.length)
      .| chunksOverTime 1
      .| go
  where
    -- The state holds the number of bytes seen so far. Each batched chunk
    -- size from 'chunksOverTime' produces exactly one progress message.
    go = evalStateC 0 $ awaitForever $ \(Sum size) -> do
        modify (+ size)
        totalSoFar <- get
        logSticky $ fromString $
            case mtotalSize of
                Nothing -> chattyProgressNoTotal totalSoFar
                -- A zero total would make the percentage a division by zero.
                Just 0 -> chattyProgressNoTotal totalSoFar
                Just totalSize -> chattyProgressWithTotal totalSoFar totalSize

    -- Example: ghc: 42.13 KiB downloaded...
    chattyProgressNoTotal :: Int -> String
    chattyProgressNoTotal totalSoFar =
        printf ("%s: " <> bytesfmt "%7.2f" totalSoFar <> " downloaded...")
               (T.unpack label)

    -- Example: ghc: 50.00 MiB / 100.00 MiB (50.00%) downloaded...
    chattyProgressWithTotal :: Int -> Int -> String
    chattyProgressWithTotal totalSoFar total =
        printf ("%s: " <>
                bytesfmt "%7.2f" totalSoFar <> " / " <>
                bytesfmt "%.2f" total <>
                " (%6.2f%%) downloaded...")
               (T.unpack label)
               percentage
      where
        percentage :: Double
        percentage = fromIntegral totalSoFar / fromIntegral total * 100

-- | Given a printf format string for the decimal part and a number of
-- bytes, formats the bytes using an appropriate binary unit and returns the
-- formatted string.
--
-- The magnitude is divided by 1024 until it drops below 1024 or the largest
-- unit (YiB) is reached. The sign of the input is preserved.
--
-- >>> bytesfmt "%.2f" 512368
-- "500.36 KiB"
bytesfmt :: Integral a => String -> a -> String
bytesfmt formatter bs =
    printf (formatter <> " %s") (signum bytes * dec) (bytesSuffixes !! i)
  where
    -- Convert to Double before taking the magnitude. On a bounded integral
    -- type, 'abs' overflows for 'minBound' (e.g. @abs (minBound :: Int)@ is
    -- still negative), which would give the wrong sign and unit.
    bytes :: Double
    bytes = fromIntegral bs
    (dec, i) = getSuffix (abs bytes)
    -- Returns the scaled value and the number of divisions, which is the
    -- index into 'bytesSuffixes'.
    getSuffix :: Double -> (Double, Int)
    getSuffix n = until p (\(x, y) -> (x / 1024, y + 1)) (n, 0)
      where
        p (n', numDivs) = n' < 1024 || numDivs == length bytesSuffixes - 1
    bytesSuffixes :: [String]
    bytesSuffixes = ["B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB", "ZiB", "YiB"]

-- | Await eagerly, collecting inputs with monoidal append, but space out
-- yields in time.
--
-- The accumulated value is yielded only when an input arrives strictly more
-- than the given amount of time after the previous yield (or after the
-- conduit started). When upstream is exhausted, whatever has accumulated is
-- yielded one final time. That final yield may come sooner, and may be a
-- superfluous 'mempty'.
--
-- Note that Integer and Float literals can be turned into 'NominalDiffTime'
-- (these literals are interpreted as seconds).
chunksOverTime :: (Monoid a, Semigroup a, MonadIO m) => NominalDiffTime -> ConduitM a a m ()
chunksOverTime diff = do
    currentTime <- liftIO getCurrentTime
    evalStateC (currentTime, mempty) go
  where
    -- State is a tuple of:
    -- * the last time a yield happened (or the beginning of the sink)
    -- * the accumulated awaits since the last yield
    go = await >>= \case
      Nothing -> do
        -- Upstream finished: flush whatever is left (possibly 'mempty').
        (_, acc) <- get
        yield acc
      Just a -> do
        (lastTime, acc) <- get
        let acc' = acc <> a
        currentTime <- liftIO getCurrentTime
        -- The clock is only checked when an input arrives, so a yield can
        -- happen no earlier than the first input after the interval passes.
        if diff < diffUTCTime currentTime lastTime
          then put (currentTime, mempty) >> yield acc'
          else put (lastTime, acc')
        go