Bayesian Neural Network (Variational Inference) for Classification (Under Development)
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
# ==============================================================================
In [1]:
import warnings
warnings.filterwarnings('ignore')
import mxfusion as mf
import mxnet as mx
import numpy as np
import mxnet.gluon.nn as nn
import mxfusion.components
import mxfusion.inference
Generate Synthetic Data
In [2]:
import GPy
%matplotlib inline
from pylab import *
# Build a 1-D binary classification toy set: draw one latent function from a
# zero-mean Gaussian with an RBF-kernel covariance evaluated at random inputs,
# then label every input by whether its latent value is positive.
np.random.seed(4)
k = GPy.kern.RBF(1, lengthscale=0.1)
x = np.random.rand(200, 1)  # 200 inputs drawn uniformly from [0, 1)
latent_draw = np.random.multivariate_normal(
    mean=np.zeros((200,)),
    cov=k.K(x),
    size=(1,),
)
y = latent_draw.T > 0.  # boolean labels, one row per input
plot(x[:, 0], y[:, 0], '.')
Out[2]:
[<matplotlib.lines.Line2D at 0x11cb73748>]
In [3]:
# Network used as the likelihood's mean function: 1 input -> two tanh hidden
# layers of width D -> 2 outputs (one score per class).
D = 10

# np.random.seed only seeds NumPy. MXNet keeps its own random state, which the
# Xavier initialisation below (and any later MXNet-side sampling) draws from,
# so seed it as well to make a fresh "Restart & Run All" repeatable.
mx.random.seed(4)

net = nn.HybridSequential(prefix='nn_')
with net.name_scope():
    # flatten=False keeps the leading (data point) axes intact.
    net.add(nn.Dense(D, activation="tanh", flatten=False, in_units=1))
    net.add(nn.Dense(D, activation="tanh", flatten=False, in_units=D))
    net.add(nn.Dense(2, flatten=False, in_units=D))
net.initialize(mx.init.Xavier(magnitude=1))
In [4]:
from mxfusion.components.variables.var_trans import PositiveTransformation
from mxfusion.inference import VariationalPosteriorForwardSampling
In [5]:
# Probabilistic model: the Gluon network maps inputs x to r, and the labels y
# follow a 2-class Categorical parameterised by r.
m = mf.Model()
m.N = mf.Variable()  # used as the leading dimension of both x and y
m.f = mf.functions.MXFusionGluonFunction(
    net, num_outputs=1, broadcastable=False)
m.x = mf.Variable(shape=(m.N, 1))
m.r = m.f(m.x)

# Give every network parameter its own Normal prior (mean 0, variance 3).
# A fresh distribution object is created per parameter on purpose.
for _param_name, weight in m.r.factor.parameters.items():
    weight_prior = mf.components.distributions.Normal(
        mean=mx.nd.array([0]),
        variance=mx.nd.array([3.]),
    )
    weight.set_prior(weight_prior)

m.y = mf.distributions.Categorical.define_variable(
    log_prob=m.r,
    shape=(m.N, 1),
    num_classes=2,
)
print(m)
Variable(e06bd) ~ Normal(mean=Variable(78b11), variance=Variable(961fd))
Variable(5f8c9) ~ Normal(mean=Variable(34ce9), variance=Variable(a783a))
Variable(d2e00) ~ Normal(mean=Variable(6bda0), variance=Variable(7f89e))
Variable(b44a6) ~ Normal(mean=Variable(6b08c), variance=Variable(93b41))
Variable(d3cf6) ~ Normal(mean=Variable(ad287), variance=Variable(6700e))
Variable(7ab0d) ~ Normal(mean=Variable(160fc), variance=Variable(cbaf9))
r = GluonFunctionEvaluation(nn_input_0=x, nn_dense0_weight=Variable(7ab0d), nn_dense0_bias=Variable(d3cf6), nn_dense1_weight=Variable(b44a6), nn_dense1_bias=Variable(d2e00), nn_dense2_weight=Variable(5f8c9), nn_dense2_bias=Variable(e06bd))
y ~ Categorical(log_prob=r)
In [6]:
from mxfusion.inference import BatchInferenceLoop, create_Gaussian_meanfield, GradBasedInference, StochasticVariationalInference, MAP
In [7]:
# Variational inference set-up: a Gaussian mean-field posterior over the
# model's unobserved variables, optimised with stochastic variational
# inference (5 samples per evaluation) in a full-batch gradient loop.
observed = [m.y, m.x]
q = create_Gaussian_meanfield(model=m, observed=observed)
alg = StochasticVariationalInference(
    num_samples=5,
    model=m,
    posterior=q,
    observed=observed,
)
infr = GradBasedInference(
    inference_algorithm=alg,
    grad_loop=BatchInferenceLoop(),
)
In [8]:
# Initialise the inference object with the training data, which must happen
# before infr.params is edited by hand in the next cell.
init_data = {'y': mx.nd.array(y), 'x': mx.nd.array(x)}
infr.initialize(**init_data)
In [9]:
# Start the variational posterior at the network's current weights: for each
# network parameter, set the posterior mean to the Gluon parameter value and
# the posterior variance to a tiny constant (1e-8).
for param_name, param in m.r.factor.parameters.items():
    q_factor = infr.inference_algorithm.posterior[param.uuid].factor
    current_variance = infr.params.param_dict[q_factor.variance.uuid].data().asnumpy()
    # Same shape and dtype as the existing variance, every entry 1e-8.
    tiny_variance = np.full_like(current_variance, 1e-8)
    infr.params[q_factor.mean] = net.collect_params()[param_name].data()
    infr.params[q_factor.variance] = mx.nd.array(tiny_variance)
In [10]:
# Optimise the variational posterior on the full training set.
infr.run(
    max_iter=500,
    learning_rate=1e-1,
    y=mx.nd.array(y),
    x=mx.nd.array(x),
    verbose=True,
)
Iteration 51 loss: 1066.1125488281255
Iteration 101 loss: 675.2722167968758
Iteration 151 loss: 345.30307006835945
Iteration 201 loss: 196.68641662597656
Iteration 251 loss: 155.23381042480478
Iteration 301 loss: 149.42289733886724
Iteration 351 loss: 159.71490478515625
Iteration 401 loss: 140.58926391601562
Iteration 451 loss: 174.78173828125438
Iteration 500 loss: 127.64309692382812
In [11]:
# Print the value of every posterior variable that has an entry in the
# inference parameter dictionary (variables without one are skipped).
fitted_params = infr.params.param_dict
for var_uuid, var in infr.inference_algorithm.posterior.variables.items():
    if var_uuid not in fitted_params:
        continue
    print(var.name, infr.params[var])
None
[[-4.2562084 ]
[-2.1897657 ]
[-2.7514694 ]
[-1.8618754 ]
[ 0.05935706]
[ 2.3460457 ]
[-2.6491752 ]
[-1.2179427 ]
[-0.08034295]
[-0.5979197 ]]
<NDArray 10x1 @cpu(0)>
None
[[0.01767311]
[0.00603309]
[0.00848725]
[0.05524571]
[0.68462664]
[0.00412718]
[0.03144019]
[0.11763541]
[2.0121076 ]
[0.37226263]]
<NDArray 10x1 @cpu(0)>
None
[ 1.6983228 1.4742194 1.775172 0.6392376 0.31661415 -1.6325905
1.398058 0.7429083 0.04331838 0.3991743 ]
<NDArray 10 @cpu(0)>
None
[0.00463977 0.00697836 0.00277748 0.04407028 0.45794868 0.00267045
0.01026547 0.07324851 0.6038953 0.06482685]
<NDArray 10 @cpu(0)>
None
[[-5.0367075e-01 -3.6483032e-01 -3.4889570e-01 -3.7278756e-01
-5.8295298e-01 2.0773776e-01 -5.1646495e-01 -5.6319767e-01
1.2088771e-01 -3.6126822e-02]
[ 3.2355504e-03 -3.6068845e-01 -1.8626985e-01 -1.8437026e-01
6.3100457e-04 4.4206291e-01 -2.7084729e-02 -3.1543028e-01
-2.8092265e-01 -2.6803422e-01]
[-1.4353344e-01 2.7556152e+00 2.7566373e+00 4.7164506e-01
3.9942378e-01 -3.5447137e+00 1.2198279e+00 2.0113483e-01
7.4260637e-02 1.0011230e-01]
[ 3.5851079e-01 7.9171979e-01 6.1348730e-01 5.8377886e-01
5.4714572e-01 -1.0298078e+00 3.3680087e-01 1.1881048e-02
4.9028376e-01 -1.4387065e-01]
[ 7.1333803e-02 2.1075387e-01 2.7103132e-02 -6.6015087e-02
1.6656926e-01 -3.9778087e-01 1.8710904e-01 4.3254908e-02
-1.5939955e-01 -2.0810342e-01]
[-1.2169343e-01 9.4294645e-02 3.3085659e-01 -1.9831542e-02
2.8470194e-01 -3.8632959e-01 -7.6368101e-02 1.3375407e-01
5.9273201e-01 -4.5699142e-02]
[ 1.9255243e-01 2.9560938e-01 2.5773695e-01 -1.0506964e-01
-2.5529373e-01 -3.1061968e-01 3.3579066e-02 3.4898770e-01
6.0322829e-02 2.8761932e-01]
[-7.0459640e-01 -6.9609299e-02 -3.6901351e-02 -4.2581716e-01
3.1552029e-01 -1.8861942e-01 -6.2215298e-01 -4.0387815e-01
-7.6213044e-01 1.4895415e-01]
[ 1.5128514e+00 2.8877625e-01 4.8491848e-01 4.6291590e-01
6.4278495e-01 -3.2827693e-01 8.9393836e-01 3.9123634e-01
-2.0554188e-01 1.9961188e-02]
[ 3.8393707e+00 1.5004866e+00 2.1594408e+00 1.6014071e+00
-2.5904796e-01 -1.5982518e+00 2.0201933e+00 1.7498626e-01
2.3529473e-01 2.0874260e-01]]
<NDArray 10x10 @cpu(0)>
None
[[2.708519 2.7912157 3.0016038 2.6399708 2.1374285 3.3056285
3.338311 3.417747 3.0662622 2.665338 ]
[2.8390265 2.83211 3.2312827 2.9462285 3.0979178 3.0666673
2.9610286 2.8243313 3.183116 3.1657238 ]
[0.09525167 0.5299614 0.18995798 0.10806751 0.0597569 0.66866034
0.05129567 0.24625055 0.0597211 0.13327149]
[3.2492511 3.2614036 3.3345017 3.371056 3.053195 2.24815
2.216518 3.073331 2.4673629 2.2618537 ]
[2.2471485 2.8116536 2.6036153 2.3638754 3.2123742 2.5266416
3.2636497 3.4483907 2.9033678 2.3266923 ]
[3.0316195 2.9135869 2.8787353 2.9725506 3.2761152 3.0925238
2.5114353 2.7284532 2.269938 3.5903633 ]
[3.1938636 2.8134184 3.1856582 3.7374485 2.2276769 3.2222536
2.922795 2.6769052 2.965645 2.879921 ]
[2.2802768 2.479438 3.3902857 3.9532485 2.38113 3.2590485
2.944619 2.6350875 2.6051073 3.4462888 ]
[1.8125408 2.109326 2.3153167 2.4868069 2.156431 2.2759461
2.7648149 2.4953694 2.2933164 3.1460094 ]
[0.77330697 0.21263564 0.27486432 0.37804347 0.09345496 0.18871015
1.4391748 0.44938883 0.19104742 0.3376724 ]]
<NDArray 10x10 @cpu(0)>
None
[ 0.19256005 -0.00668432 2.0988328 0.03577445 0.11586528 0.46011198
-0.0560313 0.2638263 -0.9565386 -1.5131551 ]
<NDArray 10 @cpu(0)>
None
[2.5745904 1.9814985 0.03306409 3.0929534 3.9594617 2.3214269
2.8359015 3.0308347 1.3825488 0.07675996]
<NDArray 10 @cpu(0)>
None
[[-0.03301222 0.0079538 2.3766918 0.07191022 0.01081306 -0.08140734
-0.01184111 -0.20975497 -0.10157776 -2.483351 ]
[ 0.02959426 -0.00993267 -2.3763514 -0.09460182 -0.00604838 0.04880186
0.02897908 0.20790808 0.07653924 2.4683366 ]]
<NDArray 2x10 @cpu(0)>
None
[[0.21937881 0.20881172 0.14296758 0.19930215 0.19051579 0.15455456
0.09204628 0.10818221 0.12479458 0.16975753]
[0.11583243 0.0629681 0.0659797 0.09927363 0.09680786 0.15992676
0.08846851 0.09655253 0.10265005 0.08145336]]
<NDArray 2x10 @cpu(0)>
None
[-2.461291 2.4445803]
<NDArray 2 @cpu(0)>
None
[0.09449326 0.08444152]
<NDArray 2 @cpu(0)>
In [12]:
# Test inputs: 100 evenly spaced points covering [0, 1], as a column vector.
xt = np.linspace(0, 1, 100).reshape(-1, 1)
In [13]:
# Forward-sample the network output m.r at the test inputs, reusing the
# posterior fitted by `infr`.
# NOTE(review): the leading 10 is presumably the number of samples, since the
# plotting loop iterates over the first axis of the result. Confirm against
# the VariationalPosteriorForwardSampling signature.
infr2 = VariationalPosteriorForwardSampling(
    10,
    [m.x],   # variable supplied at run time
    infr,    # fitted inference object
    [m.r],   # variable read back from the result
)
res = infr2.run(x=mx.nd.array(xt))
In [14]:
# Pull the sampled values of m.r out of the result and convert to NumPy.
r_samples = res[m.r]
yt = r_samples.asnumpy()
In [15]:
# Summary statistics across the first (sample) axis of yt.
yt_mean = yt.mean(axis=0)
yt_std = yt.std(axis=0)

# For two outputs r0, r1: softmax(r)[1] = 1 / (1 + exp(r0 - r1)).
# Computed for all samples at once, then one faint curve drawn per sample.
class1_prob = 1. / (1 + np.exp(yt[:, :, 0] - yt[:, :, 1]))
for sample_curve in class1_prob:
    plot(xt[:, 0], sample_curve, 'k', alpha=0.2)

# Overlay the training labels for comparison.
plot(x[:, 0], y[:, 0], '.')
Out[15]:
[<matplotlib.lines.Line2D at 0x11d47e5f8>]