Bayesian Neural Network (VI) for classification (under Development)

# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
#   Licensed under the Apache License, Version 2.0 (the "License").
#   You may not use this file except in compliance with the License.
#   A copy of the License is located at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   or in the "license" file accompanying this file. This file is distributed
#   on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
#   express or implied. See the License for the specific language governing
#   permissions and limitations under the License.
# ==============================================================================
In [1]:
import warnings
warnings.filterwarnings('ignore')
import mxfusion as mf
import mxnet as mx
import numpy as np
import mxnet.gluon.nn as nn
import mxfusion.components
import mxfusion.inference

Generate Synthetic Data

In [2]:
import GPy
%matplotlib inline
from pylab import *

np.random.seed(4)
k = GPy.kern.RBF(1, lengthscale=0.1)
x = np.random.rand(200,1)
y = np.random.multivariate_normal(mean=np.zeros((200,)), cov=k.K(x), size=(1,)).T>0.
plot(x[:,0], y[:,0], '.')
Out[2]:
[<matplotlib.lines.Line2D at 0x11cb73748>]
../../_images/examples_notebooks_bnn_classification_4_1.png
In [3]:
D = 10
net = nn.HybridSequential(prefix='nn_')
with net.name_scope():
    net.add(nn.Dense(D, activation="tanh", flatten=False, in_units=1))
    net.add(nn.Dense(D, activation="tanh", flatten=False, in_units=D))
    net.add(nn.Dense(2, flatten=False, in_units=D))
net.initialize(mx.init.Xavier(magnitude=1))
In [4]:
from mxfusion.components.variables.var_trans import PositiveTransformation
from mxfusion.inference import VariationalPosteriorForwardSampling
In [5]:
m = mf.Model()
m.N = mf.Variable()
m.f = mf.functions.MXFusionGluonFunction(net, num_outputs=1, broadcastable=False)
m.x = mf.Variable(shape=(m.N,1))
m.r = m.f(m.x)
for _,v in m.r.factor.parameters.items():
    v.set_prior(mf.components.distributions.Normal(mean=mx.nd.array([0]),variance=mx.nd.array([3.])))
m.y = mf.distributions.Categorical.define_variable(log_prob=m.r, shape=(m.N,1), num_classes=2)
print(m)
Variable(e06bd) ~ Normal(mean=Variable(78b11), variance=Variable(961fd))
Variable(5f8c9) ~ Normal(mean=Variable(34ce9), variance=Variable(a783a))
Variable(d2e00) ~ Normal(mean=Variable(6bda0), variance=Variable(7f89e))
Variable(b44a6) ~ Normal(mean=Variable(6b08c), variance=Variable(93b41))
Variable(d3cf6) ~ Normal(mean=Variable(ad287), variance=Variable(6700e))
Variable(7ab0d) ~ Normal(mean=Variable(160fc), variance=Variable(cbaf9))
r = GluonFunctionEvaluation(nn_input_0=x, nn_dense0_weight=Variable(7ab0d), nn_dense0_bias=Variable(d3cf6), nn_dense1_weight=Variable(b44a6), nn_dense1_bias=Variable(d2e00), nn_dense2_weight=Variable(5f8c9), nn_dense2_bias=Variable(e06bd))
y ~ Categorical(log_prob=r)
In [6]:
from mxfusion.inference import BatchInferenceLoop, create_Gaussian_meanfield, GradBasedInference, StochasticVariationalInference, MAP
In [7]:
observed = [m.y, m.x]
q = create_Gaussian_meanfield(model=m, observed=observed)
alg = StochasticVariationalInference(num_samples=5, model=m, posterior=q, observed=observed)
infr = GradBasedInference(inference_algorithm=alg, grad_loop=BatchInferenceLoop())
In [8]:
infr.initialize(y=mx.nd.array(y), x=mx.nd.array(x))
In [9]:
for v_name, v in m.r.factor.parameters.items():
    uuid = v.uuid
    loc_uuid = infr.inference_algorithm.posterior[uuid].factor.variance.uuid
    a = infr.params.param_dict[loc_uuid].data().asnumpy()
    a[:] = 1e-8
    infr.params[infr.inference_algorithm.posterior[uuid].factor.mean] = net.collect_params()[v_name].data()
    infr.params[infr.inference_algorithm.posterior[uuid].factor.variance] = mx.nd.array(a)
In [10]:
infr.run(max_iter=500, learning_rate=1e-1, y=mx.nd.array(y), x=mx.nd.array(x), verbose=True)
Iteration 51 loss: 1066.1125488281255
Iteration 101 loss: 675.2722167968758
Iteration 151 loss: 345.30307006835945
Iteration 201 loss: 196.68641662597656
Iteration 251 loss: 155.23381042480478
Iteration 301 loss: 149.42289733886724
Iteration 351 loss: 159.71490478515625
Iteration 401 loss: 140.58926391601562
Iteration 451 loss: 174.78173828125438
Iteration 500 loss: 127.64309692382812
In [11]:
for uuid, v in infr.inference_algorithm.posterior.variables.items():
    if uuid in infr.params.param_dict:
        print(v.name, infr.params[v])
None
[[-4.2562084 ]
 [-2.1897657 ]
 [-2.7514694 ]
 [-1.8618754 ]
 [ 0.05935706]
 [ 2.3460457 ]
 [-2.6491752 ]
 [-1.2179427 ]
 [-0.08034295]
 [-0.5979197 ]]
<NDArray 10x1 @cpu(0)>
None
[[0.01767311]
 [0.00603309]
 [0.00848725]
 [0.05524571]
 [0.68462664]
 [0.00412718]
 [0.03144019]
 [0.11763541]
 [2.0121076 ]
 [0.37226263]]
<NDArray 10x1 @cpu(0)>
None
[ 1.6983228   1.4742194   1.775172    0.6392376   0.31661415 -1.6325905
  1.398058    0.7429083   0.04331838  0.3991743 ]
<NDArray 10 @cpu(0)>
None
[0.00463977 0.00697836 0.00277748 0.04407028 0.45794868 0.00267045
 0.01026547 0.07324851 0.6038953  0.06482685]
<NDArray 10 @cpu(0)>
None
[[-5.0367075e-01 -3.6483032e-01 -3.4889570e-01 -3.7278756e-01
  -5.8295298e-01  2.0773776e-01 -5.1646495e-01 -5.6319767e-01
   1.2088771e-01 -3.6126822e-02]
 [ 3.2355504e-03 -3.6068845e-01 -1.8626985e-01 -1.8437026e-01
   6.3100457e-04  4.4206291e-01 -2.7084729e-02 -3.1543028e-01
  -2.8092265e-01 -2.6803422e-01]
 [-1.4353344e-01  2.7556152e+00  2.7566373e+00  4.7164506e-01
   3.9942378e-01 -3.5447137e+00  1.2198279e+00  2.0113483e-01
   7.4260637e-02  1.0011230e-01]
 [ 3.5851079e-01  7.9171979e-01  6.1348730e-01  5.8377886e-01
   5.4714572e-01 -1.0298078e+00  3.3680087e-01  1.1881048e-02
   4.9028376e-01 -1.4387065e-01]
 [ 7.1333803e-02  2.1075387e-01  2.7103132e-02 -6.6015087e-02
   1.6656926e-01 -3.9778087e-01  1.8710904e-01  4.3254908e-02
  -1.5939955e-01 -2.0810342e-01]
 [-1.2169343e-01  9.4294645e-02  3.3085659e-01 -1.9831542e-02
   2.8470194e-01 -3.8632959e-01 -7.6368101e-02  1.3375407e-01
   5.9273201e-01 -4.5699142e-02]
 [ 1.9255243e-01  2.9560938e-01  2.5773695e-01 -1.0506964e-01
  -2.5529373e-01 -3.1061968e-01  3.3579066e-02  3.4898770e-01
   6.0322829e-02  2.8761932e-01]
 [-7.0459640e-01 -6.9609299e-02 -3.6901351e-02 -4.2581716e-01
   3.1552029e-01 -1.8861942e-01 -6.2215298e-01 -4.0387815e-01
  -7.6213044e-01  1.4895415e-01]
 [ 1.5128514e+00  2.8877625e-01  4.8491848e-01  4.6291590e-01
   6.4278495e-01 -3.2827693e-01  8.9393836e-01  3.9123634e-01
  -2.0554188e-01  1.9961188e-02]
 [ 3.8393707e+00  1.5004866e+00  2.1594408e+00  1.6014071e+00
  -2.5904796e-01 -1.5982518e+00  2.0201933e+00  1.7498626e-01
   2.3529473e-01  2.0874260e-01]]
<NDArray 10x10 @cpu(0)>
None
[[2.708519   2.7912157  3.0016038  2.6399708  2.1374285  3.3056285
  3.338311   3.417747   3.0662622  2.665338  ]
 [2.8390265  2.83211    3.2312827  2.9462285  3.0979178  3.0666673
  2.9610286  2.8243313  3.183116   3.1657238 ]
 [0.09525167 0.5299614  0.18995798 0.10806751 0.0597569  0.66866034
  0.05129567 0.24625055 0.0597211  0.13327149]
 [3.2492511  3.2614036  3.3345017  3.371056   3.053195   2.24815
  2.216518   3.073331   2.4673629  2.2618537 ]
 [2.2471485  2.8116536  2.6036153  2.3638754  3.2123742  2.5266416
  3.2636497  3.4483907  2.9033678  2.3266923 ]
 [3.0316195  2.9135869  2.8787353  2.9725506  3.2761152  3.0925238
  2.5114353  2.7284532  2.269938   3.5903633 ]
 [3.1938636  2.8134184  3.1856582  3.7374485  2.2276769  3.2222536
  2.922795   2.6769052  2.965645   2.879921  ]
 [2.2802768  2.479438   3.3902857  3.9532485  2.38113    3.2590485
  2.944619   2.6350875  2.6051073  3.4462888 ]
 [1.8125408  2.109326   2.3153167  2.4868069  2.156431   2.2759461
  2.7648149  2.4953694  2.2933164  3.1460094 ]
 [0.77330697 0.21263564 0.27486432 0.37804347 0.09345496 0.18871015
  1.4391748  0.44938883 0.19104742 0.3376724 ]]
<NDArray 10x10 @cpu(0)>
None
[ 0.19256005 -0.00668432  2.0988328   0.03577445  0.11586528  0.46011198
 -0.0560313   0.2638263  -0.9565386  -1.5131551 ]
<NDArray 10 @cpu(0)>
None
[2.5745904  1.9814985  0.03306409 3.0929534  3.9594617  2.3214269
 2.8359015  3.0308347  1.3825488  0.07675996]
<NDArray 10 @cpu(0)>
None
[[-0.03301222  0.0079538   2.3766918   0.07191022  0.01081306 -0.08140734
  -0.01184111 -0.20975497 -0.10157776 -2.483351  ]
 [ 0.02959426 -0.00993267 -2.3763514  -0.09460182 -0.00604838  0.04880186
   0.02897908  0.20790808  0.07653924  2.4683366 ]]
<NDArray 2x10 @cpu(0)>
None
[[0.21937881 0.20881172 0.14296758 0.19930215 0.19051579 0.15455456
  0.09204628 0.10818221 0.12479458 0.16975753]
 [0.11583243 0.0629681  0.0659797  0.09927363 0.09680786 0.15992676
  0.08846851 0.09655253 0.10265005 0.08145336]]
<NDArray 2x10 @cpu(0)>
None
[-2.461291   2.4445803]
<NDArray 2 @cpu(0)>
None
[0.09449326 0.08444152]
<NDArray 2 @cpu(0)>
In [12]:
xt = np.linspace(0,1,100)[:,None]
In [13]:
infr2 = VariationalPosteriorForwardSampling(10, [m.x], infr, [m.r])
res = infr2.run(x=mx.nd.array(xt))
In [14]:
yt = res[m.r].asnumpy()
In [15]:
yt_mean = yt.mean(0)
yt_std = yt.std(0)
for i in range(yt.shape[0]):
    plot(xt[:,0],1./(1+np.exp(yt[i,:,0]-yt[i,:,1])),'k',alpha=0.2)
plot(x[:,0],y[:,0],'.')
Out[15]:
[<matplotlib.lines.Line2D at 0x11d47e5f8>]
../../_images/examples_notebooks_bnn_classification_17_1.png