{-# LANGUAGE CPP #-}
module PureSAT.Trail where

#define ASSERTING(x)

import Data.Primitive.PrimVar   (PrimVar, newPrimVar, readPrimVar, writePrimVar)

import PureSAT.Base
import PureSAT.Prim
import PureSAT.Clause2
import PureSAT.Level
import PureSAT.LitTable
import PureSAT.LitVar
import PureSAT.Utils

-------------------------------------------------------------------------------
-- Trail
-------------------------------------------------------------------------------

-- | A stack of literals in the 'ST' monad, backed by a mutable array.
--
-- Slots @[0, size)@ of the array hold the literals currently on the trail,
-- in the order they were pushed; slots beyond that are unused capacity.
data Trail s = Trail
    !(PrimVar s Int)          -- number of literals currently on the trail
    !(MutablePrimArray s Lit) -- backing storage; its length is the capacity

-- | Allocate an empty trail whose backing array has room for @capacity@
-- literals. The size counter starts at zero.
newTrail :: Int -> ST s (Trail s)
newTrail capacity = Trail <$> newPrimVar 0 <*> newPrimArray capacity

-- | Make an independent copy of a trail.
--
-- The copy gets the same capacity and a fresh size counter. Only the live
-- prefix (the first @size@ slots) is copied. Because both the counter and
-- the array are newly allocated, later pushes or pops on one trail do not
-- affect the other.
cloneTrail :: Trail s -> ST s (Trail s)
cloneTrail (Trail size ls) = do
    capacity <- getSizeofMutablePrimArray ls
    n        <- readPrimVar size
    size'    <- newPrimVar n
    ls'      <- newPrimArray capacity
    copyMutablePrimArray ls' 0 ls 0 n
    return (Trail size' ls')

-- | Ensure the trail can hold at least @newCapacity@ literals.
--
-- If the current capacity already suffices, the same trail is returned
-- unchanged. Otherwise a new trail is allocated, with capacity
-- @nextPowerOf2 newCapacity@, and the live prefix is copied into it. The
-- result then has its own size counter and array, so callers must continue
-- with the returned trail rather than the argument.
extendTrail :: Trail s -> Int -> ST s (Trail s)
extendTrail trail@(Trail size ls) newCapacity = do
    oldCapacity <- getSizeofMutablePrimArray ls
    -- Short-circuit on the request itself, not on the rounded-up size. The
    -- old check (nextPowerOf2 (max old new) <= old) needlessly reallocated
    -- trails whose capacity was not already a power of two.
    if newCapacity <= oldCapacity
    then return trail
    else do
        -- NOTE(review): assumes nextPowerOf2 x >= x (defined in PureSAT.Utils).
        let capacity = nextPowerOf2 newCapacity
        n     <- readPrimVar size
        size' <- newPrimVar n
        ls'   <- newPrimArray capacity
        copyMutablePrimArray ls' 0 ls 0 n
        return (Trail size' ls')

-- | Read the literal at position @i@ (0 = bottom of the trail).
--
-- The size counter is not consulted. An index at or past the current size
-- but within capacity reads a stale slot, so callers must keep @i@ below
-- the trail size themselves.
indexTrail :: Trail s -> Int -> ST s Lit
indexTrail (Trail _ ls) i = readPrimArray ls i

-- | Remove and return the most recently pushed literal.
--
-- Precondition: the trail is non-empty. This is only checked when the
-- ASSERTING macro is enabled.
popTrail :: Trail s -> ST s Lit
popTrail (Trail size ls) = do
    n <- readPrimVar size
    ASSERTING(assertST "non empty trail" (n >= 1))
    let top = n - 1
    writePrimVar size top
    readPrimArray ls top

-- | Push a literal onto the top of the trail.
--
-- The backing array is never grown here. The caller must ensure there is
-- spare capacity (see 'extendTrail'). The check below only runs when the
-- ASSERTING macro is enabled, mirroring the precondition check in
-- 'popTrail'.
pushTrail :: Lit -> Trail s -> ST s ()
pushTrail l (Trail size ls) = do
    n <- readPrimVar size
    ASSERTING(assertST "trail not full" . (n <) =<< getSizeofMutablePrimArray ls)
    writePrimVar size (n + 1)
    writePrimArray ls n l

-- | Debug helper: emit the whole trail through 'traceM', bottom first.
--
-- Each line shows an @\@@-prefixed level, then either @Decided lit@ (the
-- literal's reason clause is null) or @Deduced lit clause@. The lines are
-- framed by header and footer rule lines.
traceTrail :: forall s. LitTable s Clause2 -> Levels s -> Trail s -> ST s ()
traceTrail reasons levels (Trail size lits) = do
    n       <- readPrimVar size
    entries <- mapM describe [0 .. n - 1]
    traceM $ unlines $ "=== Trail ===" : entries ++ ["=== ===== ==="]
  where
    -- Render the trail entry at position i.
    describe :: Int -> ST s String
    describe i = do
        l       <- readPrimArray lits i
        Level d <- getLevel levels l
        c       <- readLitTable reasons l
        let prefix = showChar '@' . shows d
        return $ if isNullClause c
            then (prefix . showString " Decided " . showsPrec 11 l) ""
            else (prefix . showString " Deduced " . showsPrec 11 l . showChar ' ' . showsPrec 11 c) ""

-- | Assert, via 'assertST', that the trail currently holds no literals.
-- The caller's call stack is propagated for error reporting.
assertEmptyTrail :: HasCallStack => Trail s -> ST s ()
assertEmptyTrail (Trail size _) = do
    n <- readPrimVar size
    assertST "n == 0" $ n == 0