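Bayesian Neural Network (Variational Inference) for Classification (under development)
This notebook fits a small Bayesian neural network to a synthetic 1-D binary classification problem using stochastic variational inference in MXFusion.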

In [1]:
import mxfusion as mf
import mxnet as mx
import numpy as np
import mxnet.gluon.nn as nn
import mxfusion.components
import mxfusion.inference
/Users/zhenwend/anaconda3/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
  from ._conv import register_converters as _register_converters

Generate Synthetic Data

In [75]:
import GPy
%matplotlib inline
from pylab import *

np.random.seed(4)
k = GPy.kern.RBF(1, lengthscale=0.1)
x = np.random.rand(200,1)
y = np.random.multivariate_normal(mean=np.zeros((200,)), cov=k.K(x), size=(1,)).T>0.
plot(x[:,0], y[:,0], '.')
Out[75]:
[<matplotlib.lines.Line2D at 0x1a22348898>]
../../_images/examples_notebooks_bnn_classification_3_1.png
In [76]:
# Gluon MLP mapping a scalar input to 2 outputs (one per class):
#   1 -> D (tanh) -> D (tanh) -> 2 (no activation)
D = 10
net = nn.HybridSequential(prefix='nn_')
with net.name_scope():
    # Each entry is (units, activation, in_units). The insertion order fixes the
    # parameter names nn_dense0_*, nn_dense1_*, nn_dense2_* that later cells
    # look up via net.collect_params().
    for units, act, fan_in in ((D, "tanh", 1), (D, "tanh", D), (2, None, D)):
        net.add(nn.Dense(units, activation=act, flatten=False, in_units=fan_in))
net.initialize(mx.init.Xavier(magnitude=1))
In [77]:
from mxfusion.components.variables.var_trans import PositiveTransformation
from mxfusion.inference import VariationalPosteriorForwardSampling
In [78]:
# Generative model: y ~ Categorical(log_prob = f(x)), where f is the Gluon
# network `net` and every weight/bias of `net` is a random variable.
# Attribute assignment on `m` is what names the variables (N, f, x, r, y in
# m.show()), so statement order here is deliberate.
m = mf.components.Model()
# Symbolic number of data points; used in the shapes of x and y below.
m.N = mf.components.Variable()
# Wrap the Gluon network so it can be applied to MXFusion variables.
m.f = mf.components.functions.MXFusionGluonFunction(net, nOutputs=1, broadcastable=False)
m.x = mf.components.Variable(shape=(m.N,1))
# Network output: 2 values per data point (the last Dense layer has 2 units).
m.r = m.f(m.x)
# block_variables yields (parameter name, variable) pairs for the network's
# weights and biases; give each a Normal prior with mean 0 and VARIANCE 3
# (not standard deviation 3).
for _,v in m.r.factor.block_variables:
    v.set_prior(mf.components.distributions.Normal(mean=mx.nd.array([0]),variance=mx.nd.array([3.])))
# Labels: the 2 network outputs are passed as class log-probabilities.
# NOTE(review): presumably Categorical normalises them (softmax) — confirm.
m.y = mf.components.distributions.Categorical.define_variable(log_prob=m.r,  shape=(m.N,1))
m.show()
Variable(45b58) ~ Normal(mean=Variable(738a7), variance=Variable(9ef9b))
Variable(d7aa2) ~ Normal(mean=Variable(b4564), variance=Variable(02ad0))
Variable(ace48) ~ Normal(mean=Variable(3989f), variance=Variable(49e6a))
Variable(35ae3) ~ Normal(mean=Variable(666b4), variance=Variable(5c0d8))
Variable(7a303) ~ Normal(mean=Variable(9d461), variance=Variable(dd20b))
Variable(28ccc) ~ Normal(mean=Variable(37c47), variance=Variable(a7de6))
r = GluonFunctionEvaluation(nn_dense0_weight=Variable(28ccc), nn_dense0_bias=Variable(7a303), nn_dense1_weight=Variable(35ae3), nn_dense1_bias=Variable(ace48), nn_dense2_weight=Variable(d7aa2), nn_dense2_bias=Variable(45b58), nn_input_0=x)
y ~ Categorical(log_prob=r)
In [79]:
from mxfusion.inference import BatchInferenceLoop, create_Gaussian_meanfield, GradBasedInference, StochasticVariationalInference, MAP
In [80]:
# Variational inference set-up: x and y are observed; the remaining model
# variables (the network parameters) get a mean-field Gaussian posterior q.
observed = [m.y, m.x]
q = create_Gaussian_meanfield(model=m, observed=observed)
alg = StochasticVariationalInference(
    model=m,
    observed=observed,
    posterior=q,
    num_samples=5,  # presumably Monte Carlo samples per objective estimate — verify
)
# Alternative: fit a MAP point estimate instead of the variational posterior:
#     alg = MAP(model=m, observed=observed)
infr = GradBasedInference(grad_loop=BatchInferenceLoop(), inference_algorithm=alg)
In [81]:
# Bind the training data; this (re)creates the inference parameters held in infr.params.
infr.initialize(
    y=mx.nd.array(y),
    x=mx.nd.array(x),
)
 /Users/zhenwend/mxfusion/src/MXFusion/mxfusion/inference/inference_parameters.py:52: UserWarning:InferenceParameters has already been initialized.  The existing one will be overwritten.
In [82]:
# Start the variational posterior at the (Xavier-initialised) Gluon network:
# each parameter's posterior mean is set to the current value held by `net`,
# and its posterior variance to a tiny constant, so optimisation begins from
# an almost-deterministic network.
POSTERIOR_INIT_VARIANCE = 1e-8
posterior = infr.inference_algorithm.posterior
for v_name, v in m.r.factor.block_variables:
    q_factor = posterior[v.uuid].factor
    # Only the shape of the existing variance parameter is needed; build the
    # new value directly instead of copying the parameter to NumPy and mutating it.
    variance_shape = infr.params.param_dict[q_factor.variance.uuid].data().shape
    infr.params[q_factor.mean] = net.collect_params()[v_name].data()
    infr.params[q_factor.variance] = mx.nd.full(variance_shape, POSTERIOR_INIT_VARIANCE)
In [83]:
# Optimise the variational objective; verbose=True prints the logL every iteration.
infr.run(
    y=mx.nd.array(y),
    x=mx.nd.array(x),
    max_iter=500,
    learning_rate=1e-1,
    verbose=True,
)
 /Users/zhenwend/mxfusion/src/MXFusion/mxfusion/inference/inference.py:111: UserWarning:Trying to initialize the inference twice, skipping.
Iteration 1 logL: -1544.5255126953125
Iteration 2 logL: -1539.4837646484375
Iteration 3 logL: -1511.021484375
Iteration 4 logL: -1505.3983154296875
Iteration 5 logL: -1494.5648193359375
Iteration 6 logL: -1491.2451171875
Iteration 7 logL: -1478.662841796875
Iteration 8 logL: -1478.6864013671875
Iteration 9 logL: -1471.1865234375
Iteration 10 logL: -1452.61962890625
Iteration 11 logL: -1451.615478515625
Iteration 12 logL: -1444.682373046875
Iteration 13 logL: -1445.5955810546875
Iteration 14 logL: -1430.4305419921875
Iteration 15 logL: -1418.5712890625
Iteration 16 logL: -1418.8111572265625
Iteration 17 logL: -1404.1358642578125
Iteration 18 logL: -1400.627685546875
Iteration 19 logL: -1381.745361328125
Iteration 20 logL: -1376.2139892578125
Iteration 21 logL: -1372.777099609375
Iteration 22 logL: -1366.9664306640625
Iteration 23 logL: -1361.6920166015625
Iteration 24 logL: -1342.0687255859375
Iteration 25 logL: -1341.3878173828125
Iteration 26 logL: -1335.53271484375
Iteration 27 logL: -1321.6600341796875
Iteration 28 logL: -1322.144287109375
Iteration 29 logL: -1304.44482421875
Iteration 30 logL: -1296.4647216796875
Iteration 31 logL: -1288.6043701171875
Iteration 32 logL: -1288.419189453125
Iteration 33 logL: -1273.767333984375
Iteration 34 logL: -1261.6514892578125
Iteration 35 logL: -1246.7862548828125
Iteration 36 logL: -1237.2008056640625
Iteration 37 logL: -1230.070556640625
Iteration 38 logL: -1217.3485107421875
Iteration 39 logL: -1210.28466796875
Iteration 40 logL: -1196.698486328125
Iteration 41 logL: -1179.746826171875
Iteration 42 logL: -1172.70556640625
Iteration 43 logL: -1153.209716796875
Iteration 44 logL: -1144.113037109375
Iteration 45 logL: -1131.1912841796875
Iteration 46 logL: -1122.3773193359375
Iteration 47 logL: -1108.426025390625
Iteration 48 logL: -1095.648193359375
Iteration 49 logL: -1082.5948486328125
Iteration 50 logL: -1079.716552734375
Iteration 51 logL: -1066.1383056640625
Iteration 52 logL: -1065.216064453125
Iteration 53 logL: -1048.493896484375
Iteration 54 logL: -1039.2891845703125
Iteration 55 logL: -1040.451416015625
Iteration 56 logL: -1024.9017333984375
Iteration 57 logL: -1007.782958984375
Iteration 58 logL: -1014.7578125
Iteration 59 logL: -991.4736328125
Iteration 60 logL: -991.7649536132812
Iteration 61 logL: -981.59716796875
Iteration 62 logL: -974.068603515625
Iteration 63 logL: -972.3157958984375
Iteration 64 logL: -960.890625
Iteration 65 logL: -958.5018310546875
Iteration 66 logL: -945.4229736328125
Iteration 67 logL: -937.4677124023438
Iteration 68 logL: -921.200439453125
Iteration 69 logL: -915.6636962890625
Iteration 70 logL: -912.5822143554688
Iteration 71 logL: -905.74267578125
Iteration 72 logL: -900.7443237304688
Iteration 73 logL: -883.1107177734375
Iteration 74 logL: -886.2918701171875
Iteration 75 logL: -879.94775390625
Iteration 76 logL: -861.6558837890625
Iteration 77 logL: -854.5107421875
Iteration 78 logL: -851.0504150390625
Iteration 79 logL: -846.66796875
Iteration 80 logL: -838.7744140625
Iteration 81 logL: -826.3006591796875
Iteration 82 logL: -817.10791015625
Iteration 83 logL: -806.6436767578125
Iteration 84 logL: -803.6990966796875
Iteration 85 logL: -794.893310546875
Iteration 86 logL: -792.5907592773438
Iteration 87 logL: -776.2742919921875
Iteration 88 logL: -775.6581420898438
Iteration 89 logL: -767.6826782226562
Iteration 90 logL: -755.8037109375
Iteration 91 logL: -750.8147583007812
Iteration 92 logL: -741.1118774414062
Iteration 93 logL: -739.661376953125
Iteration 94 logL: -724.867431640625
Iteration 95 logL: -721.9619750976562
Iteration 96 logL: -713.5444946289062
Iteration 97 logL: -700.5989379882812
Iteration 98 logL: -692.425537109375
Iteration 99 logL: -685.3265380859375
Iteration 100 logL: -682.6221923828125
Iteration 101 logL: -667.1160888671875
Iteration 102 logL: -671.22021484375
Iteration 103 logL: -657.33984375
Iteration 104 logL: -647.55712890625
Iteration 105 logL: -638.2933959960938
Iteration 106 logL: -632.623046875
Iteration 107 logL: -632.192626953125
Iteration 108 logL: -622.1544189453125
Iteration 109 logL: -612.350341796875
Iteration 110 logL: -607.607666015625
Iteration 111 logL: -607.3804931640625
Iteration 112 logL: -591.3848876953125
Iteration 113 logL: -579.00537109375
Iteration 114 logL: -574.9884033203125
Iteration 115 logL: -573.00390625
Iteration 116 logL: -558.756103515625
Iteration 117 logL: -556.7438354492188
Iteration 118 logL: -548.1097412109375
Iteration 119 logL: -527.03515625
Iteration 120 logL: -540.4559326171875
Iteration 121 logL: -524.711181640625
Iteration 122 logL: -530.5860595703125
Iteration 123 logL: -521.8486328125
Iteration 124 logL: -502.2843017578125
Iteration 125 logL: -501.84820556640625
Iteration 126 logL: -491.798828125
Iteration 127 logL: -509.27178955078125
Iteration 128 logL: -473.0491943359375
Iteration 129 logL: -490.3575439453125
Iteration 130 logL: -465.9736328125
Iteration 131 logL: -459.7388916015625
Iteration 132 logL: -450.2601318359375
Iteration 133 logL: -440.4290771484375
Iteration 134 logL: -436.97113037109375
Iteration 135 logL: -439.14788818359375
Iteration 136 logL: -414.6134948730469
Iteration 137 logL: -409.9062805175781
Iteration 138 logL: -396.1278076171875
Iteration 139 logL: -414.68865966796875
Iteration 140 logL: -395.8081359863281
Iteration 141 logL: -386.2045593261719
Iteration 142 logL: -378.982666015625
Iteration 143 logL: -368.39642333984375
Iteration 144 logL: -365.70941162109375
Iteration 145 logL: -368.5234680175781
Iteration 146 logL: -360.26531982421875
Iteration 147 logL: -352.19720458984375
Iteration 148 logL: -347.9356689453125
Iteration 149 logL: -353.3387756347656
Iteration 150 logL: -327.9718933105469
Iteration 151 logL: -345.4718017578125
Iteration 152 logL: -333.1092834472656
Iteration 153 logL: -302.427490234375
Iteration 154 logL: -317.42205810546875
Iteration 155 logL: -289.9781799316406
Iteration 156 logL: -296.4144592285156
Iteration 157 logL: -304.27593994140625
Iteration 158 logL: -306.234375
Iteration 159 logL: -284.4197998046875
Iteration 160 logL: -284.7406921386719
Iteration 161 logL: -274.4178466796875
Iteration 162 logL: -267.80206298828125
Iteration 163 logL: -244.25775146484375
Iteration 164 logL: -254.48374938964844
Iteration 165 logL: -242.91217041015625
Iteration 166 logL: -248.5866241455078
Iteration 167 logL: -250.1661376953125
Iteration 168 logL: -273.66912841796875
Iteration 169 logL: -240.9576416015625
Iteration 170 logL: -256.94140625
Iteration 171 logL: -275.43536376953125
Iteration 172 logL: -214.0106964111328
Iteration 173 logL: -214.67474365234375
Iteration 174 logL: -232.027587890625
Iteration 175 logL: -228.86508178710938
Iteration 176 logL: -240.1166534423828
Iteration 177 logL: -198.66574096679688
Iteration 178 logL: -216.43634033203125
Iteration 179 logL: -207.2248077392578
Iteration 180 logL: -203.01507568359375
Iteration 181 logL: -201.1699676513672
Iteration 182 logL: -247.60006713867188
Iteration 183 logL: -226.33963012695312
Iteration 184 logL: -253.77496337890625
Iteration 185 logL: -191.73776245117188
Iteration 186 logL: -247.42324829101562
Iteration 187 logL: -201.26058959960938
Iteration 188 logL: -177.60061645507812
Iteration 189 logL: -190.0717010498047
Iteration 190 logL: -317.8877868652344
Iteration 191 logL: -215.68914794921875
Iteration 192 logL: -184.3809356689453
Iteration 193 logL: -180.11679077148438
Iteration 194 logL: -205.02462768554688
Iteration 195 logL: -224.42221069335938
Iteration 196 logL: -244.3284912109375
Iteration 197 logL: -224.40420532226562
Iteration 198 logL: -205.4558868408203
Iteration 199 logL: -188.27389526367188
Iteration 200 logL: -218.32017517089844
Iteration 201 logL: -174.22677612304688
Iteration 202 logL: -199.7232208251953
Iteration 203 logL: -157.99331665039062
Iteration 204 logL: -192.910400390625
Iteration 205 logL: -200.99237060546875
Iteration 206 logL: -163.43582153320312
Iteration 207 logL: -213.77777099609375
Iteration 208 logL: -166.8415985107422
Iteration 209 logL: -161.51039123535156
Iteration 210 logL: -176.5122833251953
Iteration 211 logL: -187.05154418945312
Iteration 212 logL: -153.88722229003906
Iteration 213 logL: -188.95333862304688
Iteration 214 logL: -172.976318359375
Iteration 215 logL: -238.98878479003906
Iteration 216 logL: -174.09642028808594
Iteration 217 logL: -168.273681640625
Iteration 218 logL: -189.6393585205078
Iteration 219 logL: -216.9544677734375
Iteration 220 logL: -171.2532196044922
Iteration 221 logL: -174.9771270751953
Iteration 222 logL: -206.28012084960938
Iteration 223 logL: -191.87185668945312
Iteration 224 logL: -153.0070037841797
Iteration 225 logL: -181.36277770996094
Iteration 226 logL: -169.47288513183594
Iteration 227 logL: -156.15509033203125
Iteration 228 logL: -179.83694458007812
Iteration 229 logL: -183.0515594482422
Iteration 230 logL: -163.0454864501953
Iteration 231 logL: -184.2140350341797
Iteration 232 logL: -219.8138885498047
Iteration 233 logL: -162.65086364746094
Iteration 234 logL: -144.58377075195312
Iteration 235 logL: -179.9770965576172
Iteration 236 logL: -183.8132781982422
Iteration 237 logL: -199.29489135742188
Iteration 238 logL: -151.03273010253906
Iteration 239 logL: -160.94741821289062
Iteration 240 logL: -176.21104431152344
Iteration 241 logL: -188.37820434570312
Iteration 242 logL: -196.54141235351562
Iteration 243 logL: -162.23831176757812
Iteration 244 logL: -141.35182189941406
Iteration 245 logL: -152.73011779785156
Iteration 246 logL: -237.92613220214844
Iteration 247 logL: -210.1531219482422
Iteration 248 logL: -185.01890563964844
Iteration 249 logL: -185.64022827148438
Iteration 250 logL: -221.623046875
Iteration 251 logL: -216.6830596923828
Iteration 252 logL: -160.9233856201172
Iteration 253 logL: -172.77687072753906
Iteration 254 logL: -165.9558868408203
Iteration 255 logL: -201.72703552246094
Iteration 256 logL: -164.8269500732422
Iteration 257 logL: -165.69264221191406
Iteration 258 logL: -183.5315704345703
Iteration 259 logL: -269.22235107421875
Iteration 260 logL: -146.8050994873047
Iteration 261 logL: -212.0876922607422
Iteration 262 logL: -191.9491729736328
Iteration 263 logL: -218.36244201660156
Iteration 264 logL: -162.2360382080078
Iteration 265 logL: -175.89859008789062
Iteration 266 logL: -202.32484436035156
Iteration 267 logL: -160.58258056640625
Iteration 268 logL: -159.10276794433594
Iteration 269 logL: -164.344482421875
Iteration 270 logL: -211.7244873046875
Iteration 271 logL: -192.30337524414062
Iteration 272 logL: -158.49166870117188
Iteration 273 logL: -171.47390747070312
Iteration 274 logL: -183.94886779785156
Iteration 275 logL: -171.7067413330078
Iteration 276 logL: -151.35958862304688
Iteration 277 logL: -173.07603454589844
Iteration 278 logL: -154.04678344726562
Iteration 279 logL: -192.27493286132812
Iteration 280 logL: -192.57839965820312
Iteration 281 logL: -181.84156799316406
Iteration 282 logL: -185.34580993652344
Iteration 283 logL: -153.8499298095703
Iteration 284 logL: -152.19866943359375
Iteration 285 logL: -144.4599151611328
Iteration 286 logL: -165.732421875
Iteration 287 logL: -176.54833984375
Iteration 288 logL: -164.82205200195312
Iteration 289 logL: -161.06195068359375
Iteration 290 logL: -176.03353881835938
Iteration 291 logL: -151.64808654785156
Iteration 292 logL: -194.75949096679688
Iteration 293 logL: -175.03729248046875
Iteration 294 logL: -165.12147521972656
Iteration 295 logL: -158.97984313964844
Iteration 296 logL: -181.7843780517578
Iteration 297 logL: -162.631591796875
Iteration 298 logL: -177.43496704101562
Iteration 299 logL: -188.93605041503906
Iteration 300 logL: -175.7605743408203
Iteration 301 logL: -159.49876403808594
Iteration 302 logL: -156.8190460205078
Iteration 303 logL: -158.90673828125
Iteration 304 logL: -162.634765625
Iteration 305 logL: -197.76136779785156
Iteration 306 logL: -179.81996154785156
Iteration 307 logL: -160.31796264648438
Iteration 308 logL: -182.8149871826172
Iteration 309 logL: -166.40286254882812
Iteration 310 logL: -161.11827087402344
Iteration 311 logL: -158.80226135253906
Iteration 312 logL: -172.4361114501953
Iteration 313 logL: -171.3463592529297
Iteration 314 logL: -152.3768310546875
Iteration 315 logL: -158.77395629882812
Iteration 316 logL: -213.81350708007812
Iteration 317 logL: -145.558349609375
Iteration 318 logL: -181.20938110351562
Iteration 319 logL: -185.07322692871094
Iteration 320 logL: -207.5948028564453
Iteration 321 logL: -167.41371154785156
Iteration 322 logL: -170.78697204589844
Iteration 323 logL: -158.92723083496094
Iteration 324 logL: -168.30673217773438
Iteration 325 logL: -145.81129455566406
Iteration 326 logL: -171.64073181152344
Iteration 327 logL: -182.69529724121094
Iteration 328 logL: -154.89712524414062
Iteration 329 logL: -151.62460327148438
Iteration 330 logL: -142.93319702148438
Iteration 331 logL: -158.87815856933594
Iteration 332 logL: -133.32958984375
Iteration 333 logL: -152.91761779785156
Iteration 334 logL: -176.51593017578125
Iteration 335 logL: -176.9428253173828
Iteration 336 logL: -195.5206756591797
Iteration 337 logL: -146.91213989257812
Iteration 338 logL: -174.56198120117188
Iteration 339 logL: -157.5927734375
Iteration 340 logL: -163.74888610839844
Iteration 341 logL: -149.73818969726562
Iteration 342 logL: -141.78402709960938
Iteration 343 logL: -150.82537841796875
Iteration 344 logL: -203.70167541503906
Iteration 345 logL: -234.94883728027344
Iteration 346 logL: -138.6446990966797
Iteration 347 logL: -155.61611938476562
Iteration 348 logL: -170.1373291015625
Iteration 349 logL: -137.4384307861328
Iteration 350 logL: -159.34925842285156
Iteration 351 logL: -189.43511962890625
Iteration 352 logL: -146.205078125
Iteration 353 logL: -162.0035400390625
Iteration 354 logL: -152.1077117919922
Iteration 355 logL: -177.6432342529297
Iteration 356 logL: -162.13482666015625
Iteration 357 logL: -164.792236328125
Iteration 358 logL: -175.9607391357422
Iteration 359 logL: -175.19512939453125
Iteration 360 logL: -153.57888793945312
Iteration 361 logL: -149.2701873779297
Iteration 362 logL: -151.72816467285156
Iteration 363 logL: -151.6085662841797
Iteration 364 logL: -164.00177001953125
Iteration 365 logL: -169.85508728027344
Iteration 366 logL: -174.19601440429688
Iteration 367 logL: -150.11888122558594
Iteration 368 logL: -183.55679321289062
Iteration 369 logL: -175.8426055908203
Iteration 370 logL: -149.37100219726562
Iteration 371 logL: -185.88963317871094
Iteration 372 logL: -151.3091583251953
Iteration 373 logL: -161.74493408203125
Iteration 374 logL: -146.36911010742188
Iteration 375 logL: -159.70030212402344
Iteration 376 logL: -146.5471954345703
Iteration 377 logL: -183.76512145996094
Iteration 378 logL: -162.58302307128906
Iteration 379 logL: -154.88087463378906
Iteration 380 logL: -153.092529296875
Iteration 381 logL: -149.21633911132812
Iteration 382 logL: -159.79013061523438
Iteration 383 logL: -207.46981811523438
Iteration 384 logL: -166.4166717529297
Iteration 385 logL: -163.28904724121094
Iteration 386 logL: -139.99237060546875
Iteration 387 logL: -159.22752380371094
Iteration 388 logL: -147.1323699951172
Iteration 389 logL: -131.47760009765625
Iteration 390 logL: -173.95697021484375
Iteration 391 logL: -175.2364501953125
Iteration 392 logL: -164.6923828125
Iteration 393 logL: -170.76512145996094
Iteration 394 logL: -150.3489532470703
Iteration 395 logL: -149.53738403320312
Iteration 396 logL: -165.0814971923828
Iteration 397 logL: -159.6266326904297
Iteration 398 logL: -205.7377166748047
Iteration 399 logL: -201.95794677734375
Iteration 400 logL: -173.90322875976562
Iteration 401 logL: -219.3822479248047
Iteration 402 logL: -170.65428161621094
Iteration 403 logL: -154.12339782714844
Iteration 404 logL: -145.6951446533203
Iteration 405 logL: -154.26612854003906
Iteration 406 logL: -168.78125
Iteration 407 logL: -151.6161346435547
Iteration 408 logL: -137.94374084472656
Iteration 409 logL: -162.55181884765625
Iteration 410 logL: -195.5110626220703
Iteration 411 logL: -181.4525146484375
Iteration 412 logL: -176.29464721679688
Iteration 413 logL: -161.6767578125
Iteration 414 logL: -147.3198699951172
Iteration 415 logL: -174.78868103027344
Iteration 416 logL: -241.66427612304688
Iteration 417 logL: -194.82431030273438
Iteration 418 logL: -163.0545196533203
Iteration 419 logL: -142.1942138671875
Iteration 420 logL: -161.17933654785156
Iteration 421 logL: -151.19564819335938
Iteration 422 logL: -270.526123046875
Iteration 423 logL: -174.3603973388672
Iteration 424 logL: -145.0341339111328
Iteration 425 logL: -174.63157653808594
Iteration 426 logL: -134.2644805908203
Iteration 427 logL: -198.86883544921875
Iteration 428 logL: -149.4174346923828
Iteration 429 logL: -144.0459442138672
Iteration 430 logL: -161.4966583251953
Iteration 431 logL: -134.34768676757812
Iteration 432 logL: -134.56263732910156
Iteration 433 logL: -180.6781463623047
Iteration 434 logL: -158.80487060546875
Iteration 435 logL: -159.59298706054688
Iteration 436 logL: -139.64794921875
Iteration 437 logL: -133.45777893066406
Iteration 438 logL: -152.6923370361328
Iteration 439 logL: -156.44227600097656
Iteration 440 logL: -157.8925018310547
Iteration 441 logL: -167.68077087402344
Iteration 442 logL: -161.24176025390625
Iteration 443 logL: -159.16549682617188
Iteration 444 logL: -172.21835327148438
Iteration 445 logL: -156.8822479248047
Iteration 446 logL: -133.2530975341797
Iteration 447 logL: -153.4013214111328
Iteration 448 logL: -133.96762084960938
Iteration 449 logL: -138.84234619140625
Iteration 450 logL: -183.0357666015625
Iteration 451 logL: -170.00534057617188
Iteration 452 logL: -157.46279907226562
Iteration 453 logL: -155.4312286376953
Iteration 454 logL: -178.26963806152344
Iteration 455 logL: -136.466552734375
Iteration 456 logL: -142.8925018310547
Iteration 457 logL: -159.1484832763672
Iteration 458 logL: -173.41558837890625
Iteration 459 logL: -161.4503936767578
Iteration 460 logL: -165.84906005859375
Iteration 461 logL: -173.14651489257812
Iteration 462 logL: -155.19857788085938
Iteration 463 logL: -174.43728637695312
Iteration 464 logL: -161.63243103027344
Iteration 465 logL: -140.7918701171875
Iteration 466 logL: -161.52334594726562
Iteration 467 logL: -162.0040283203125
Iteration 468 logL: -146.77978515625
Iteration 469 logL: -145.1008758544922
Iteration 470 logL: -145.02806091308594
Iteration 471 logL: -126.25932312011719
Iteration 472 logL: -134.5507049560547
Iteration 473 logL: -172.2655487060547
Iteration 474 logL: -178.7369384765625
Iteration 475 logL: -175.0013885498047
Iteration 476 logL: -155.21022033691406
Iteration 477 logL: -140.81431579589844
Iteration 478 logL: -186.590087890625
Iteration 479 logL: -147.77403259277344
Iteration 480 logL: -167.0510711669922
Iteration 481 logL: -152.2871856689453
Iteration 482 logL: -151.39828491210938
Iteration 483 logL: -129.93521118164062
Iteration 484 logL: -143.72450256347656
Iteration 485 logL: -146.91224670410156
Iteration 486 logL: -127.20298767089844
Iteration 487 logL: -144.4951934814453
Iteration 488 logL: -145.83363342285156
Iteration 489 logL: -132.7200164794922
Iteration 490 logL: -143.2606658935547
Iteration 491 logL: -144.3072509765625
Iteration 492 logL: -148.9021759033203
Iteration 493 logL: -143.65380859375
Iteration 494 logL: -141.4383544921875
Iteration 495 logL: -165.1541290283203
Iteration 496 logL: -136.8163604736328
Iteration 497 logL: -158.55410766601562
Iteration 498 logL: -138.8023681640625
Iteration 499 logL: -158.96397399902344
Iteration 500 logL: -138.33013916015625
In [12]:
# Debug aid: print every posterior variable that has a learned value in infr.params.
for var_uuid, var in infr.inference_algorithm.posterior.variables.items():
    if var_uuid not in infr.params.param_dict:
        continue
    print(var.name, infr.params[var])
In [84]:
# Test inputs: 100 evenly spaced points covering the training range [0, 1], as a column.
xt = np.linspace(0, 1, 100).reshape(-1, 1)
In [85]:
# Sample the network output r at the test inputs, reusing the posterior fitted
# by `infr`. The first constructor argument is presumably the number of
# posterior draws — confirm against the MXFusion API.
n_draws = 10
infr2 = VariationalPosteriorForwardSampling(n_draws, [m.x], infr, [m.r])
res = infr2.run(x=mx.nd.array(xt))
 /Users/zhenwend/mxfusion/src/MXFusion/mxfusion/core/factor_graph.py:65: UserWarning:The value N has already been assigned in the model.
 /Users/zhenwend/mxfusion/src/MXFusion/mxfusion/core/factor_graph.py:65: UserWarning:The value y has already been assigned in the model.
 /Users/zhenwend/mxfusion/src/MXFusion/mxfusion/inference/inference_parameters.py:52: UserWarning:InferenceParameters has already been initialized.  The existing one will be overwritten.
In [86]:
# Pull the sampled network outputs back into NumPy for plotting.
r_samples_nd = res[m.r]
yt = r_samples_nd.asnumpy()
In [87]:
# Posterior predictive check. yt is indexed [sample, test point, output].
# The model passes the two network outputs to Categorical as class
# log-probabilities, so P(y=1 | x) = softmax(r)[1] = sigmoid(r1 - r0);
# sigmoid of the first output alone is not a class probability.
# NOTE(review): assumes Categorical softmax-normalises log_prob — confirm.
shifted = yt - yt.max(axis=-1, keepdims=True)  # shift before exp to avoid overflow
probs = np.exp(shifted)
probs /= probs.sum(axis=-1, keepdims=True)
# One faint curve per posterior sample of P(y=1 | x), over the training labels.
for p_class1 in probs[:, :, 1]:
    plot(xt[:, 0], p_class1, 'k', alpha=0.2)
plot(x[:, 0], y[:, 0], '.')
xlabel('x')
ylabel('P(y = 1 | x)');
Out[87]:
[<matplotlib.lines.Line2D at 0x1a222e37f0>]
../../_images/examples_notebooks_bnn_classification_16_1.png