-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathLib.hs
More file actions
430 lines (335 loc) · 10.9 KB
/
Lib.hs
File metadata and controls
430 lines (335 loc) · 10.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE GADTs #-}
module Lib where
import Numeric
import Prelude hiding (and,or,(>))
import Control.Monad.RWS
import System.Process
import System.Exit
import Data.List (intercalate)
import qualified Data.Map as M
import qualified Data.Vector.Unboxed as V
import Statistics.Sample.KernelDensity
import Text.Read (readMaybe)
--------------------------------------------
-- Basic code-gen infrastructure
-- | Linear-arrow notation. This is only an alias for the ordinary function
-- arrow; nothing in this module enforces linearity.
type a ⊸ b = a -> b
-- | Typed syntax tree for generated JavaScript (WebPPL) expressions. The type
-- parameter @t@ records the Haskell-level type of the denoted value; 'render'
-- turns a tree into source text.
data Expr t where
  -- Literal source text, emitted verbatim by 'render'.
  Expr :: String -> Expr t -- constant expression
  -- Higher-order abstract syntax: when rendered, the Haskell function is
  -- applied to a freshly named variable.
  Lam :: (Expr a -> Expr t) -> Expr (a -> t) -- lambda
  App :: Expr (a -> t) -> Expr a -> Expr t -- application
  -- Binary operator: the function combines the two rendered operands.
  -- NOTE(review): operand and result types are unconstrained here.
  BinOp :: (String -> String -> String) -> Expr a -> Expr b -> Expr c
  -- Unary operator: the function wraps the rendered operand.
  UnOp :: (String -> String) -> Expr a -> Expr b
  -- Conditional, rendered as a JavaScript ternary.
  If :: Prop -> Expr a -> Expr a -> Expr a
  -- Let :: P (Expr a) -> (Expr a ⊸ Expr b) -> Expr b
  -- Delayed computation: the statements it emits are wrapped, together with
  -- its result, in a zero-argument JavaScript function when rendered.
  Thunk :: P (Expr t) -> Expr t
-- | Haskell values that can be embedded as literals in generated code.
class Representable a where
  -- | The JavaScript literal denoting the given value.
  constant :: a -> Expr a

instance Representable Bool where
  constant b = Expr (if b then "true" else "false")

-- Relies on Haskell's 'show' for the literal text.
instance Representable Float where
  constant x = Expr (show x)
-- | Division and rational literals for generated floating-point expressions.
-- Rational literals are converted to 'Float' first, then embedded via 'constant'.
instance Fractional (Expr Float) where
  x / y = BinOp (infixOp "/") x y
  fromRational r = constant (fromRat r :: Float)
-- | Arithmetic on generated floating-point expressions. Every method builds
-- syntax; nothing is evaluated on the Haskell side.
instance Num (Expr Float) where
  (+) = BinOp (infixOp "+")
  (-) = BinOp (infixOp "-")
  (*) = BinOp (infixOp "*")
  abs = UnOp (\x -> "Math.abs(" ++ x ++ ")")
  -- Defined so that 'signum' does not fail with a missing-method error at
  -- runtime; JavaScript's Math.sign has the same -1/0/1 meaning.
  signum = UnOp (\x -> "Math.sign(" ++ x ++ ")")
  -- Integer literals become Float literals. 'negate' keeps the class
  -- default (0 - x), so negation is rendered as a subtraction.
  fromInteger n = constant (fromInteger n)
-- | Transcendental functions on generated expressions, rendered as calls on
-- JavaScript's Math object.
--
-- Every method without a class default is defined, so none of them fails at
-- runtime with a missing-method error. 'tan' and 'tanh' are given directly
-- because their defaults go through 'sin'/'cos' and 'sinh'/'cosh'. '(**)' and
-- 'logBase' keep their class defaults, which are expressed via 'exp' and 'log'.
instance Floating (Expr Float) where
  pi = Expr "Math.PI"
  sqrt = UnOp (\x -> "Math.sqrt(" ++ x ++ ")")
  exp = UnOp (\x -> "Math.exp(" ++ x ++ ")")
  log = UnOp (\x -> "Math.log(" ++ x ++ ")")
  sin = UnOp (\x -> "Math.sin(" ++ x ++ ")")
  cos = UnOp (\x -> "Math.cos(" ++ x ++ ")")
  tan = UnOp (\x -> "Math.tan(" ++ x ++ ")")
  asin = UnOp (\x -> "Math.asin(" ++ x ++ ")")
  acos = UnOp (\x -> "Math.acos(" ++ x ++ ")")
  atan = UnOp (\x -> "Math.atan(" ++ x ++ ")")
  sinh = UnOp (\x -> "Math.sinh(" ++ x ++ ")")
  cosh = UnOp (\x -> "Math.cosh(" ++ x ++ ")")
  tanh = UnOp (\x -> "Math.tanh(" ++ x ++ ")")
  asinh = UnOp (\x -> "Math.asinh(" ++ x ++ ")")
  acosh = UnOp (\x -> "Math.acosh(" ++ x ++ ")")
  atanh = UnOp (\x -> "Math.atanh(" ++ x ++ ")")
-- | Render an expression to JavaScript source text.
--
-- Rendering runs in 'P' because lambdas need fresh variable names and thunks
-- must capture the statements emitted by their body. Sub-expressions are
-- always rendered left to right.
render :: Expr a -> P String
render e = case e of
  Expr src -> pure src
  UnOp f x -> f <$> render x
  BinOp f x y -> f <$> render x <*> render y
  -- A lambda applied directly is beta-reduced at generation time: the
  -- argument expression is substituted into the body (and is therefore
  -- duplicated if the body uses it more than once).
  App (Lam body) arg -> render (body arg)
  App f arg -> do
    fSrc <- render f
    argSrc <- render arg
    pure (parens fSrc ++ parens argSrc)
  Lam body -> do
    v <- newVar
    bodySrc <- render (body (Expr v))
    pure ("function( " ++ v ++ ") { return " ++ bodySrc ++ "; }")
  Thunk action -> do
    -- 'listen' collects what the action emits; 'censor' removes it from the
    -- enclosing output, so those statements only appear inside the function.
    (result, stmts) <- censor (const []) (listen action)
    resultSrc <- render result
    pure ("function() {\n" ++ unlines stmts ++ " return " ++ resultSrc ++ ";\n }")
  If c x y -> do
    cSrc <- render c
    xSrc <- render x
    ySrc <- render y
    pure ("(" ++ cSrc ++ "?" ++ xSrc ++ ":" ++ ySrc ++ ")")
  -- Let e f -> do
  --   x <- newVar
  --   e' <- render =<< e -- works only if f uses its argument at most once.
  --   fx <- render (f (Expr x))
  --   return ("var " ++ x ++ " = "++ e' ++";\n" ++ fx)
-- | Infix application of an embedded function (a synonym for 'App').
-- Level 9 is what the original precedence-less @infixl@ declaration meant.
infixl 9 #
(#) :: Expr (a -> b) -> Expr a -> Expr b
f # x = App f x
-- | Wrap a string in parentheses.
parens :: [Char] -> [Char]
parens s = '(' : s ++ ")"

-- | Render an infix operator with both operands parenthesised,
-- e.g. @infixOp "+" "a" "b" == "(a)+(b)"@.
infixOp :: [Char] -> [Char] -> [Char] -> [Char]
infixOp op lhs rhs = concat [parens lhs, op, parens rhs]

-- | Render a two-argument call, e.g. @binFunc "f" "a" "b" == "f(a,b)"@.
binFunc :: [Char] -> [Char] -> [Char] -> [Char]
binFunc name a b = name ++ parens (a ++ "," ++ b)

-- | Render a one-argument call, e.g. @unFunc "f" "a" == "f(a)"@.
unFunc :: [Char] -> [Char] -> [Char]
unFunc name a = name ++ parens a
-- | Code-generation monad. The writer output is the list of emitted
-- JavaScript statements (one per element); the state supplies fresh variable
-- names. The reader environment is unit and unused.
newtype P a = P (RWS () [String] State a)
  deriving (Functor,Applicative,Monad,MonadState State, MonadWriter [String])
-- | Generator state: the counter 'newVar' uses to build fresh names.
data State = State {nextVar :: Int}
-- | Emit a @console.log@ statement that prints @"LOG: "@, the message, a space
-- and the runtime value of @x@.
--
-- The message is encoded as a proper JavaScript string literal. Haskell's
-- 'show' would produce Haskell-only escapes (decimal codes such as \955,
-- mnemonics such as \SOH, the \& separator) for control and non-ASCII
-- characters, which JavaScript reads differently. Those characters are written
-- as \uXXXX escapes instead; printable ASCII yields exactly what 'show' gave.
logtrace :: String -> Expr a -> P ()
logtrace msg x = do
  x' <- render x
  emit ("console.log(" ++ jsString "LOG: " ++ "+" ++ jsString msg ++ "+ \" \" +" ++ x' ++ ");")
  where
    -- Double-quoted JavaScript string literal for an arbitrary Haskell string.
    jsString :: String -> String
    jsString s = "\"" ++ concatMap jsChar s ++ "\""

    jsChar :: Char -> String
    jsChar '"' = "\\\""
    jsChar '\\' = "\\\\"
    jsChar c
      | c >= ' ' && c <= '~' = [c]
      | otherwise = concatMap hexUnit (utf16 (fromEnum c))

    -- JavaScript strings are UTF-16: code points above U+FFFF need a
    -- surrogate pair.
    utf16 :: Int -> [Int]
    utf16 n
      | n < 0x10000 = [n]
      | otherwise =
          let m = n - 0x10000
           in [0xD800 + m `div` 0x400, 0xDC00 + m `mod` 0x400]

    -- A four-hex-digit \u escape for one UTF-16 code unit.
    hexUnit :: Int -> String
    hexUnit u =
      let h = showHex u ""
       in "\\u" ++ replicate (4 - length h) '0' ++ h
-- | Compile a model into a JavaScript program. The model is wrapped in a
-- thunk and passed to the JavaScript function named @mainFunc@. Any statements
-- emitted outside that thunk come first, followed by the call itself.
-- Fresh variable numbering starts at 0.
compileModel :: String -> P (Expr a) -> String
compileModel mainFunc m =
  let P action = render (UnOp (unFunc mainFunc) (Thunk m))
      (call, _, stmts) = runRWS action () (State {nextVar = 0})
   in unlines (stmts ++ [call])
-- | Conjunction of generated propositions (JavaScript @&&@).
and :: Prop -> Prop -> Prop
and p q = BinOp (infixOp "&&") p q

-- | Unicode synonym for 'and'.
(∧) :: Prop -> Prop -> Prop
p ∧ q = and p q

-- | Equivalence, rendered with JavaScript @==@.
iff :: Prop -> Prop -> Prop
iff p q = BinOp (infixOp "==") p q

-- | Disjunction (JavaScript @||@).
or :: Prop -> Prop -> Prop
or p q = BinOp (infixOp "||") p q

-- | Negation (JavaScript @!@), fully parenthesised.
not' :: Prop -> Prop
not' = UnOp (\s -> concat ["(!(", s, "))"])

-- | Material implication: @p --> q@ means @not' p `or` q@.
-- There is no fixity declaration, so the operator is @infixl 9@.
(-->) :: Prop -> Prop -> Prop
p --> q = or (not' p) q
-- | Allocate a fresh JavaScript variable name (@v0@, @v1@, ...) by reading
-- and incrementing the 'nextVar' counter.
newVar :: P String
newVar = do
  n <- gets nextVar
  modify (\s -> s {nextVar = n + 1})
  pure ('v' : show n)
-- | Append one JavaScript statement to the generated output.
emit :: String -> P ()
emit = tell . pure
-----------------------------------
-- Types
-- | A vector of generated floating-point expressions ('newVector' produces
-- vectors of length 'numberOfDimensions').
type Vec = [Expr Float]
-- | Distributions over @a@. There are no constructors; the type only tags
-- expressions such as the result of 'bernouilli'.
data Distrib a
-- | Individuals are represented as vectors.
type Ind = Vec
-- | A matrix, as a list of row vectors.
type Mat = [Vec]
-- type Vec = Expr Vector
-- | A generated boolean expression.
type Prop = Expr Bool
-- | A property of individuals.
type Pred = Ind -> Prop
-- | A real-valued score of individuals.
type Measure = Ind -> Expr Float
-- | Adjectives are represented as vectors.
type Adj = Vec
-- Presumably "adjective phrase": the measure an adjective induces.
type AP = Measure
-- Presumably common nouns (CN) and verb phrases (VP); both are predicates on
-- individuals.
type CN = Ind -> Prop
type VP = Ind -> Prop
-- | Noun phrases as properties of verb phrases (generalized-quantifier style).
type NP = VP -> Prop
-- | Quantifiers map a common noun to a noun phrase.
type Quant = CN -> NP
----------------------------------------------------
-- Compositional semantics
-- | Condition on a proposition. Currently this is exactly 'hyp'.
observe :: Prop -> P ()
observe p = hyp p
-- | The square of a number, as a single multiplication.
squared :: Num a => a -> a
squared v = v * v
-- | Soft equality: emit a WebPPL @factor@ statement with score @-(x - y)^2@.
-- The closer the two values, the higher the score.
observeEqual :: Expr Float -> Expr Float -> P ()
observeEqual x y = do
  penalty <- render (negate (squared (x - y)))
  emit (concat ["factor(", penalty, ");"])
-- | Emit a @hyp(...)@ statement for the proposition.
-- NOTE(review): @hyp@ is a JavaScript function not defined in this module,
-- presumably in the runtime file 'run' prepends; confirm its semantics there.
hyp :: Prop -> P ()
hyp p = render p >>= \src -> emit ("hyp(" ++ src ++ ");")
-- | Sample a new individual with no additional constraints.
newInd :: P Ind
newInd = newIndSuch mempty
-- | Sample a new individual, then constrain it with one 'hyp' per predicate,
-- in list order.
newIndSuch :: [Pred] -> P Ind
newIndSuch preds = do
  ind <- newVector
  mapM_ (\p -> hyp (p ind)) preds
  pure ind
-- | Length of the vectors built by 'newVector', and the number of rows
-- 'newMatrix' samples.
numberOfDimensions :: Int
numberOfDimensions = 2
-- | Sample a vector of 'numberOfDimensions' independent standard-normal
-- components. Each component is bound to its own fresh JavaScript variable.
newVector :: P Vec
newVector = replicateM numberOfDimensions (sampleGaussian 0 1)
-- | Sample a Gaussian vector exactly as 'newVector' does, then scale it to
-- unit Euclidean length. The scaling is symbolic: each component is rendered
-- as a division by the generated norm expression.
newNormedVector :: P Vec
newNormedVector = do
  -- Reuse 'newVector' rather than duplicating its sampling code, so both
  -- always draw vectors the same way.
  xs <- newVector
  let len = norm xs
  pure (map (/ len) xs)
-- | Cosine of the angle between two vectors: @(x . y) / (|x| * |y|)@.
-- NOTE(review): despite the name this is a similarity (1 for parallel
-- vectors), not a distance.
cosineDistance :: Vec -> Vec -> Expr Float
cosineDistance x y =
  let denom = norm x * norm y
   in dotProd x y / denom
-- | Squared Euclidean norm.
norm2 :: Vec -> Expr Float
norm2 v = dotProd v v

-- | Euclidean norm.
norm :: Vec -> Expr Float
norm v = sqrt (norm2 v)
-- | Sample a matrix of 'numberOfDimensions' rows, each one a fresh 'newVector'.
newMatrix :: P Mat
newMatrix = replicateM numberOfDimensions newVector
-- | Require every individual to satisfy every predicate. One 'hyp' is emitted
-- per (individual, predicate) pair, with individuals in the outer loop.
satisfyAll :: [Ind] -> [Pred] -> P ()
-- 'forM_' in both loops: the original inner 'forM' built a list of units
-- only to discard it.
satisfyAll xs ps = forM_ xs $ \x -> forM_ ps $ \p -> hyp (p x)
-- | Symbolic dot product. The vectors are zipped, so surplus components of the
-- longer one are ignored; empty input gives the literal 0.
dotProd :: Vec -> Vec -> Expr Float
dotProd x y = sum [a * b | (a, b) <- zip x y]
-- | Dot product of @v@ with each row of the matrix. Reading 'Mat' as a list
-- of rows, this is the matrix-vector product @M v@.
vecMatProd :: Vec -> Mat -> Vec
vecMatProd v rows = [dotProd v row | row <- rows]
-- | A unit direction vector paired with a bias term.
type Scalar = (Vec,Expr Float)
-- | Sample a 'Scalar': the standard-normal bias is drawn first, then the
-- unit-length direction.
newScalarA :: P Scalar
newScalarA = do
  b <- sampleGaussian 0 1
  dir <- newNormedVector
  pure (dir, b)
-- | Emit @var vN = gaussian(mu,sigma);@ and return a reference to the new
-- variable. The variable is allocated before the arguments are rendered.
sampleGaussian :: Expr Float -> Expr Float -> P (Expr Float)
sampleGaussian mu sigma = do
  name <- newVar
  muSrc <- render mu
  sigmaSrc <- render sigma
  emit (concat ["var ", name, " = gaussian(", muSrc, ",", sigmaSrc, ");"])
  pure (Expr name)
-- | Sample the representation of a new class: a random matrix from 'newMatrix'.
newClass :: P Mat
newClass = newMatrix
-- | Class-relative measure of an adjective. The adjective vector is first
-- multiplied with the class matrix ('vecMatProd'), then dotted with the
-- individual.
forClass :: Mat -> Adj -> AP
forClass cl adj = dotProd (vecMatProd adj cl)
-- -- Alternative for scalar adjectives. We take the measure to be greater than that of a "random" element of the class.
-- forClassScalar :: Mat -> Adj -> AP
-- forClassScalar cl a x = Let (newIndSuch [isClass cl]) (\y -> adjAP a x - adjAP a y)
-- | Class membership: the individual's dot product with the first row of the
-- class matrix is positive.
--
-- A class matrix with no rows is a programming error. It is reported with a
-- descriptive message instead of the bare failure of the partial 'head'.
isClass :: Mat -> Pred
isClass clas x = case clas of
  row : _ -> dotProd row x > 0
  [] -> error "isClass: class matrix has no rows"
-- | Class-independent measure of an adjective: its dot product with the
-- individual.
adjAP :: Adj -> AP
adjAP adj x = dotProd adj x
-- | Inline call @gaussian(mean,stdDev)@ inside an expression. Unlike
-- 'sampleGaussian', no variable is bound.
gaussian :: Expr Float -> Expr Float -> Expr Float
gaussian = BinOp (binFunc "gaussian")
-- | Add an inline zero-mean 'gaussian' term with standard deviation
-- @vagueness@ to a measure.
vague :: Float -> Measure -> Measure
vague vagueness m x = m x + gaussian 0 (constant vagueness)
-- | Sample a random affine measure: @bias + dir . x@ for a unit-length @dir@.
newMeasure :: P Measure
newMeasure = do
  (dir, bias) <- newScalarA
  pure (\ind -> bias + dotProd dir ind)
-- | The expression is strictly greater than zero.
positive :: Expr Float -> Prop
positive x = x `greaterThan` 0
-- | Sample a random predicate: positivity of a freshly sampled measure.
newPred :: P (Ind -> Prop)
newPred = (positive .) <$> newMeasure
-- | Draw from a distribution at runtime by applying JavaScript @sample@.
sample :: Expr (Distrib Bool) -> Expr Bool
sample = App (Expr "sample")
-- | Bernoulli distribution with success probability @p@, built with
-- @Bernoulli({p: ...})@. The Haskell name keeps its historical spelling.
bernouilli :: Expr Float -> Expr (Distrib Bool)
bernouilli = UnOp (\p -> concat ["Bernoulli({p:", p, "})"])
-- TODO: there is a choice here between two readings of "probably".
-- This is the "intuitionistic" reading (it does not exclude the strong
-- implication): a coin is drawn with bias @v@, and only when it comes up true
-- must the proposition hold.
probablyInt :: Expr Float -> Prop -> Prop
probablyInt v x = coin --> x
  where
    coin = sample (bernouilli v)
-- This is the "definite" reading (it excludes the strong implication): a coin
-- is drawn with bias @v@. If it comes up true the proposition must hold,
-- otherwise its negation must hold.
probablyDef :: Expr Float -> Prop -> Prop
probablyDef v x =
  let coin = sample (bernouilli v)
   in If coin x (not' x)
-- Note: "a --> b" and "if a then b else (not b)" are not the same!
-- | Pass a model, wrapped in a thunk, to JavaScript @expectedPortion@.
-- NOTE(review): @expectedPortion@ is not defined in this module. It
-- presumably estimates how often the thunk's result is true; confirm in the
-- runtime file.
expectedPortion :: P Prop -> Expr Float
expectedPortion m = UnOp (unFunc "expectedPortion") (Thunk m)
-- | Generic threshold quantifier. @genQProb prob cn vp@ holds when
-- 'expectedPortion' of the model "sample an individual, observe @cn@, return
-- @vp@" exceeds @prob@.
genQProb :: Expr Float -> Quant
genQProb prob cn vp = expectedPortion model > prob
  where
    model = do
      x <- newInd
      observe (cn x)
      pure (vp x)
-- | "many": the expected portion exceeds 0.6.
many :: Quant
many = genQProb 0.6

-- | "most": the expected portion exceeds 0.7.
most :: Quant
most = genQProb 0.7

-- | "few": the expected portion of @cn@ individuals that do NOT satisfy @vp@
-- exceeds 0.8.
few :: Quant
few cn vp = genQProb 0.8 cn (\x -> not' (vp x))

-- | "some": the expected portion exceeds 0.1.
some :: Quant
some = genQProb 0.1

-- | "every": the expected portion exceeds 0.99.
every :: Quant
every = genQProb 0.99
-- | The individual's measure is positive.
is :: Measure -> Pred
is m = \x -> m x > 0

-- | @more m x y@: @x@ scores strictly higher than @y@ under @m@.
more :: Measure -> Ind -> Ind -> Prop
more m x y = greaterThan (m x) (m y)
-- | Strict numeric comparison, rendered with JavaScript @>@.
greaterThan :: Expr Float -> Expr Float -> Prop
greaterThan x y = BinOp (infixOp ">") x y

-- | Operator form of 'greaterThan' (Prelude's version is hidden in this module).
-- NOTE: there is no fixity declaration, so this operator is @infixl 9@ (the
-- Haskell default) rather than Prelude's @infix 4@. For example, @a + b > c@
-- parses as @a + (b > c)@; parenthesise arithmetic operands.
(>) :: Expr Float -> Expr Float -> Prop
x > y = greaterThan x y
-- equal :: Expr Float -> Expr Float -> Prop
-- equal x y = 0.1 > (abs (x - y))
-- | Approximate equality: the absolute difference is below 0.1.
equal :: Expr Float -> Expr Float -> Prop
equal x y = greaterThan 0.1 (abs (x - y))
-- | The individual does not satisfy both predicates at once.
disjoint :: Pred -> Pred -> Pred
disjoint p q x = or (not' (p x)) (not' (q x))
-- | Subsective adjective: the individual belongs to the class and its
-- class-relative measure ('forClass') is positive.
subsective :: Adj -> Mat -> Pred
subsective a cl x = and (isClass cl x) (is (forClass cl a) x)
-- An alternative semantics for subsective scalars.
-- | The predicate that holds of every individual.
anything :: Pred
anything = const (constant True)
-- | Plot a kernel-density estimate (64 mesh points) of the samples with
-- gnuplot, writing an SVG to @prefix ++ ".svg"@. If gnuplot fails, the script
-- sent to it and its output are printed.
--
-- An empty sample is reported and skipped, because 'kde' rejects empty input
-- with an error. This happens when no output line of the model parses as a
-- number.
plot :: String -> [Double] -> IO ()
plot prefix [] = putStrLn ("No values to plot for " ++ prefix ++ ".")
plot prefix xs = do
  let (xs', ys') = kde 64 (V.fromList xs)
      fname = prefix ++ ".svg"
      ls =
        [ "set terminal svg size 350,262 fname 'Verdana' enhanced background rgb 'white'"
        , "set output '" ++ fname ++ "'"
        , "set key off" -- no legend
        , "$data << EOD"
        ]
          ++ [show x ++ " " ++ show y | (x, y) <- zip (V.toList xs') (V.toList ys')]
          ++ ["EOD", "plot '$data' with lines"] -- filledcurve
  putStrLn "Plotting results..."
  (code, output, errors) <- readProcessWithExitCode "gnuplot" ["-p"] (unlines ls)
  case code of
    ExitFailure _ -> do
      putStrLn "Gnuplot failed with input:"
      putStrLn (unlines ls)
      putStrLn "errors:"
      putStrLn output
      putStrLn errors
    ExitSuccess ->
      putStrLn ("Plot output to " ++ fname)
-- | Result types a model may return. 'run' uses 'isContinuous' to decide
-- whether to parse and plot the output (True) or print it as is (False).
-- The method mentions no value of type @a@, so callers must supply the type
-- with an explicit type application; this is why AllowAmbiguousTypes is on.
class KnownTyp a where
  isContinuous :: Bool
-- Real-valued results are continuous.
instance KnownTyp Float where
  isContinuous = True
-- Boolean results are discrete.
instance KnownTyp Bool where
  isContinuous = False
-- | Parse each line as a 'Double'. Lines that do not parse are echoed to
-- stdout, in input order, and skipped. The parsed values keep their order.
parseValues :: [String] -> IO [Double]
parseValues ls = concat <$> mapM parseLine ls
  where
    parseLine :: String -> IO [Double]
    parseLine l = case readMaybe l of
      Just v -> pure [v]
      Nothing -> [] <$ putStrLn l
-- | Compile a model, run it with @webppl@ and report the result.
--
-- The compiled model is prefixed with the runtime file ../Frontend/RTS.wppl
-- (relative to the working directory) and written to modelFunction.wppl in
-- the working directory. On success, continuous results ('isContinuous') are
-- parsed and plotted, and discrete results are printed as webppl produced
-- them. On failure, webppl's stdout and stderr are printed.
run :: forall a. KnownTyp a => P (Expr a) -> IO ()
run model = do
  putStrLn "Creating model..."
  rts <- readFile "../Frontend/RTS.wppl"
  let continuous = isContinuous @a
      entry = if continuous then "mainContinuous" else "mainDiscrete"
      base = "modelFunction"
      script = base ++ ".wppl"
  writeFile script (intercalate "\n\n" [rts, compileModel entry model])
  putStrLn "Running webppl..."
  (code, out, err) <- readProcessWithExitCode "webppl" [script] ""
  case code of
    ExitSuccess -> do
      putStrLn "Success!"
      if continuous
        then parseValues (lines out) >>= plot base
        else putStrLn out
    ExitFailure _ -> do
      putStrLn "Webppl failed with errors:"
      putStrLn out
      putStrLn err