{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE FlexibleInstances #-}
#if USE_DEFAULT_SIGNATURES
{-# LANGUAGE DefaultSignatures #-}
#endif
{-# LANGUAGE TypeFamilies #-}
-- |
-- State variables: values that can be read and/or written from any
-- 'MonadIO' monad through the 'HasGetter', 'HasSetter' and 'HasUpdate'
-- classes. Instances are provided in this module for 'StateVar',
-- 'SettableStateVar', plain 'IO' and 'STM' actions, 'Ptr', 'IORef' and
-- 'TVar'.
module Data.StateVar
(
-- * Readable state variables
HasGetter(get)
, GettableStateVar, makeGettableStateVar
-- * Writable state variables
, HasSetter(($=)), ($=!)
, SettableStateVar(SettableStateVar), makeSettableStateVar
-- * Updatable state variables
, HasUpdate(($~), ($~!))
-- * General state variables
, StateVar(StateVar), makeStateVar
, mapStateVar
) where
import Control.Concurrent.STM
import Control.Monad.IO.Class
import Data.IORef
import Data.Typeable
import Foreign.Ptr
import Foreign.Storable
#if MIN_VERSION_base(4,12,0)
import Data.Functor.Contravariant
#endif
-- | A state variable that can be both read and written.
--
-- The first field is the 'IO' action that produces the current value; the
-- second field is the action that consumes a new value.
data StateVar a
  = StateVar (IO a) (a -> IO ())
  deriving (Typeable)
#if MIN_VERSION_base(4,12,0)
-- | 'contramap' converts each incoming value with the given function before
-- handing it to the wrapped setter.
instance Contravariant SettableStateVar where
  contramap f (SettableStateVar setter) = SettableStateVar $ \b -> setter (f b)
  {-# INLINE contramap #-}
#endif
-- | Build a 'StateVar' from an action that reads the value and an action
-- that writes it.
makeStateVar
  :: IO a         -- ^ getter
  -> (a -> IO ()) -- ^ setter
  -> StateVar a
makeStateVar getter setter = StateVar getter setter
-- | Change the value type of a 'StateVar'.
--
-- The first function converts a value being written back to the underlying
-- type; the second converts a value that was read into the new type.
mapStateVar :: (b -> a) -> (a -> b) -> StateVar a -> StateVar b
mapStateVar toA fromA (StateVar getA setA) = StateVar getB setB
  where
    -- Read through the original getter, then convert the result.
    getB = fmap fromA getA
    -- Convert the incoming value, then hand it to the original setter.
    setB = setA . toA
{-# INLINE mapStateVar #-}
-- | A write-only state variable: a wrapper around the action that consumes
-- a new value.
newtype SettableStateVar a = SettableStateVar (a -> IO ())
  deriving (Typeable)
-- | Build a 'SettableStateVar' from the action that writes a value.
makeSettableStateVar
  :: (a -> IO ()) -- ^ setter
  -> SettableStateVar a
makeSettableStateVar setter = SettableStateVar setter
{-# INLINE makeSettableStateVar #-}
-- | A read-only state variable is just an 'IO' action that yields the
-- current value. The synonym is deliberately unapplied so it can be used
-- anywhere 'IO' itself can.
type GettableStateVar = IO
-- | Turn a read action into a 'GettableStateVar'. Because
-- 'GettableStateVar' is a synonym for 'IO', the action is returned
-- unchanged.
makeGettableStateVar
  :: IO a -- ^ getter
  -> GettableStateVar a
makeGettableStateVar getter = getter
{-# INLINE makeGettableStateVar #-}
infixr 2 $=, $=!

-- | Types @t@ that can be written with values of type @a@. The functional
-- dependency means the variable type determines the value type.
class HasSetter t a | t -> a where
  -- | Write a new value into a state variable.
  ($=)
    :: MonadIO m
    => t
    -> a
    -> m ()
-- | Strict variant of '$=': the new value is evaluated to weak head normal
-- form before the write is built.
($=!) :: (HasSetter t a, MonadIO m) => t -> a -> m ()
var $=! value = value `seq` (var $= value)
{-# INLINE ($=!) #-}
-- | Writing runs the wrapped setter action.
instance HasSetter (SettableStateVar a) a where
  SettableStateVar setter $= value = liftIO $ setter value
  {-# INLINE ($=) #-}
-- | Writing runs the setter half of the 'StateVar'; the getter is ignored.
instance HasSetter (StateVar a) a where
  StateVar _getter setter $= value = liftIO (setter value)
  {-# INLINE ($=) #-}
-- | Writing stores the value at the pointer with 'poke'.
instance Storable a => HasSetter (Ptr a) a where
  ptr $= value = liftIO (poke ptr value)
  {-# INLINE ($=) #-}
-- | Writing replaces the contents of the reference with 'writeIORef'.
instance HasSetter (IORef a) a where
  ref $= value = liftIO (writeIORef ref value)
  {-# INLINE ($=) #-}
-- | Writing runs a transaction of its own that consists of a single
-- 'writeTVar'.
instance HasSetter (TVar a) a where
  tvar $= value = liftIO . atomically $ writeTVar tvar value
  {-# INLINE ($=) #-}
infixr 2 $~, $~!

-- | State variables of type @t@ whose contents can be transformed by a
-- function from the value that is read (@a@) to the value that is written
-- (@b@). Every such variable must also be writable with @b@.
--
-- When @USE_DEFAULT_SIGNATURES@ is enabled, both methods default to a
-- read-then-write implementation for variables that also have a
-- 'HasGetter' instance and use the same type for reading and writing.
class HasSetter t b => HasUpdate t a b | t -> a b where
  -- | Replace the contents of a state variable with the result of applying
  -- the function to the current contents.
  ($~) :: MonadIO m => t -> (a -> b) -> m ()
#if USE_DEFAULT_SIGNATURES
  default ($~) :: (MonadIO m, a ~ b, HasGetter t a) => t -> (a -> b) -> m ()
  var $~ f = defaultUpdate var f
#endif

  -- | Strict counterpart of '$~'.
  ($~!) :: MonadIO m => t -> (a -> b) -> m ()
#if USE_DEFAULT_SIGNATURES
  default ($~!) :: (MonadIO m, a ~ b, HasGetter t a) => t -> (a -> b) -> m ()
  var $~! f = defaultUpdateStrict var f
#endif
-- | Generic update: read the current value with 'get', apply the function,
-- and write the result back with '$='. The read and the write are two
-- separate operations.
defaultUpdate :: (MonadIO m, a ~ b, HasGetter t a, HasSetter t a) => t -> (a -> b) -> m ()
defaultUpdate var f = liftIO (get var >>= \current -> var $= f current)
-- | Like 'defaultUpdate', but the result is written with '$=!', so the new
-- value is forced before it is stored.
defaultUpdateStrict :: (MonadIO m, a ~ b, HasGetter t a, HasSetter t a) => t -> (a -> b) -> m ()
defaultUpdateStrict var f = liftIO (get var >>= \current -> var $=! f current)
-- | Updates read through the getter and write back through the setter.
instance HasUpdate (StateVar a) a a where
  var $~ f = defaultUpdate var f
  var $~! f = defaultUpdateStrict var f
-- | Updates read the pointee, apply the function, and store the result
-- back, using the generic read-then-write helpers.
instance Storable a => HasUpdate (Ptr a) a a where
  ptr $~ f = defaultUpdate ptr f
  ptr $~! f = defaultUpdateStrict ptr f
-- | Updates go through the atomic modify primitives of "Data.IORef".
instance HasUpdate (IORef a) a a where
  ref $~ f = liftIO (atomicModifyIORef ref (\old -> (f old, ())))
#if MIN_VERSION_base(4,6,0)
  ref $~! f = liftIO (atomicModifyIORef' ref (\old -> (f old, ())))
#else
  -- This branch is used when the strict primitive is not available: the new
  -- value is returned from the modification and then forced with seq.
  ref $~! f = liftIO $
    atomicModifyIORef ref (\old -> let new = f old in (new, new))
      >>= \new -> new `seq` return ()
#endif
-- | Updates read and write the variable inside a single transaction.
instance HasUpdate (TVar a) a a where
  tvar $~ f =
    liftIO . atomically $
      readTVar tvar >>= \old -> writeTVar tvar (f old)
  -- The strict variant forces the new value before it is written.
  tvar $~! f =
    liftIO . atomically $
      readTVar tvar >>= \old -> writeTVar tvar $! f old
-- | Types @t@ from which a value of type @a@ can be read. The functional
-- dependency means the variable type determines the value type.
class HasGetter t a | t -> a where
  -- | Read the current value of a state variable.
  get
    :: MonadIO m
    => t
    -> m a
-- | Reading runs the getter half of the 'StateVar'; the setter is ignored.
instance HasGetter (StateVar a) a where
  get var = case var of
    StateVar getter _setter -> liftIO getter
  {-# INLINE get #-}
-- | Reading runs a transaction of its own that consists of a single
-- 'readTVar'.
instance HasGetter (TVar a) a where
  get tvar = liftIO (atomically (readTVar tvar))
  {-# INLINE get #-}
-- | An 'IO' action is its own getter: reading lifts and runs the action.
instance HasGetter (IO a) a where
  get action = liftIO action
  {-# INLINE get #-}
-- | Reading runs the transaction with 'atomically' and yields its result.
instance HasGetter (STM a) a where
  get transaction = liftIO (atomically transaction)
  {-# INLINE get #-}
-- | Reading loads the value at the pointer with 'peek'.
instance Storable a => HasGetter (Ptr a) a where
  get ptr = liftIO (peek ptr)
  {-# INLINE get #-}
-- | Reading returns the current contents of the reference via 'readIORef'.
instance HasGetter (IORef a) a where
  get ref = liftIO (readIORef ref)
  {-# INLINE get #-}