Interactive plot that shows why the KL Divergence is not symmetric.
forked from phil-pedruco's block: Plotting a bell (Gaussian) curve in d3
xxxxxxxxxx
<html>
<head>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge,chrome=1">
<title>KL Divergence</title>
<meta name="description" content="">
<script src="https://d3js.org/d3.v4.min.js"></script>
<style type="text/css">
/* Page-wide default font. */
body {
font: 10px sans-serif;
}
/* Axis strokes: applied to the path and line elements inside the groups
   the script tags with class "x axis" / "y axis". */
.axis path,
.axis line {
fill: none;
stroke: #000;
shape-rendering: crispEdges;
}
/* Plotted curves. Both curve paths use this class; the script overrides the
   stroke colour inline (red) for the second one. */
.line {
fill: none;
stroke: steelblue;
stroke-width: 1.5px;
}
/* NOTE(review): nothing in this file's script assigns the "ticks" class -
   presumably left over from the slider example; confirm before removing. */
.ticks {
font: 10px sans-serif;
}
/* Slider track. The script stacks three copies of one line element
   (track, track-inset, track-overlay); all share rounded stroke ends. */
.track,
.track-inset,
.track-overlay {
stroke-linecap: round;
}
/* Widest, semi-transparent black layer of the track. */
.track {
stroke: #000;
stroke-opacity: 0.3;
stroke-width: 10px;
}
/* Narrower light-grey layer of the track. */
.track-inset {
stroke: #ddd;
stroke-width: 8px;
}
/* Invisible 50px-wide stroke that receives pointer events; the drag
   behaviour is attached to this layer, so it defines the grab area. */
.track-overlay {
pointer-events: stroke;
stroke-width: 50px;
stroke: transparent;
cursor: crosshair;
}
/* Slider handle (the circle inserted before the overlay). */
.handle {
fill: #fff;
stroke: #000;
stroke-opacity: 0.5;
stroke-width: 1.25px;
}
</style>
</head>
<body>
</body>
<script type="text/javascript">
// Shared sample array. It is refilled in place (never replaced) so that any
// selection bound to this exact array object keeps seeing the current samples.
const data = [];

/**
 * Regenerates the sampled densities stored in `data`.
 *
 * Samples 1000 evenly spaced x values on [-10, 10] and stores, per sample,
 * `{ x, p, q }` where
 *   - p is the sum of two unit-variance Gaussians centred at -5 and +5
 *     (a plain sum, so p is not normalised to integrate to 1), and
 *   - q is a single unit-variance Gaussian centred at `q_mean`.
 *
 * Any previous contents of `data` are discarded first, so repeated calls
 * always leave exactly one sample set behind.
 *
 * @param {number} [q_mean=0] - Mean of the Gaussian q.
 * @returns {Array<{x: number, p: number, q: number}>} The shared `data` array.
 */
function getData(q_mean = 0) {
  const numPoints = 1000;
  const a = -10;
  const b = 10;
  const delta = (b - a) / (numPoints - 1);
  const qStd = 1;
  const pMeanDist = 5;
  const pStd = 1;

  // Empty the array in place instead of appending to the previous samples.
  data.length = 0;

  for (let i = 0; i < numPoints; i++) {
    // Keep the sample position local; a bare `x` would leak into (and
    // overwrite) any global of the same name.
    const x = a + i * delta;
    data.push({
      x,
      p: gaussian(x, -pMeanDist, pStd) + gaussian(x, pMeanDist, pStd),
      q: gaussian(x, q_mean, qStd),
    });
  }

  return data;
}
// Build the initial sample set (q centred on its default mean) before any
// scale reads the data.
getData();

// line chart based on https://bl.ocks.org/mbostock/3883245
const margin = { top: 20, right: 20, bottom: 30, left: 50 };
const width = 960 - margin.left - margin.right;
const height = 500 - margin.top - margin.bottom;

// Descriptive names (not `x` / `y`) keep the scales from colliding with
// one-letter sample variables used elsewhere in the script.
const xScale = d3.scaleLinear()
  .domain(d3.extent(data, (d) => d.x))
  .range([0, width])
  .clamp(true);

// Cover both curves so neither can leave the plotting area.
const yScale = d3.scaleLinear()
  .domain([0, d3.max(data, (d) => Math.max(d.p, d.q))])
  .range([height, 0]);

const xAxis = d3.axisBottom().scale(xScale);
const yAxis = d3.axisLeft().scale(yScale);

const pLine = d3.line()
  .x((d) => xScale(d.x))
  .y((d) => yScale(d.p));

const qLine = d3.line()
  .x((d) => xScale(d.x))
  .y((d) => yScale(d.q));

const svg = d3.select("body").append("svg")
  .attr("width", width + margin.left + margin.right)
  .attr("height", height + margin.top + margin.bottom)
  .append("g")
  .attr("transform", `translate(${margin.left},${margin.top})`);

// --- Slider controlling the mean of q -------------------------------------

// The track is much shorter than the chart, so it needs its own scale:
// the full x domain is mapped onto the track length, clamped at both ends.
const sliderLength = 159;
const initialQMean = 0; // matches getData()'s default
const sliderScale = d3.scaleLinear()
  .domain(xScale.domain())
  .range([0, sliderLength])
  .clamp(true);

const slider = svg.append("g")
  .attr("class", "slider")
  .attr("transform", `translate(${margin.left - 20},${margin.top})`);

// One line element cloned twice: visible track, inset, and a wide
// transparent overlay that carries the drag behaviour.
slider.append("line")
  .attr("class", "track")
  .attr("x1", sliderScale.range()[0])
  .attr("x2", sliderScale.range()[1])
  .select(function () { return this.parentNode.appendChild(this.cloneNode(true)); })
  .attr("class", "track-inset")
  .select(function () { return this.parentNode.appendChild(this.cloneNode(true)); })
  .attr("class", "track-overlay")
  .call(d3.drag()
    .on("start.interrupt", function () { slider.interrupt(); })
    .on("start drag", function () {
      // d3.event.x is expressed in the slider group's coordinates.
      setQMean(sliderScale.invert(d3.event.x));
    }));

const handle = slider.insert("circle", ".track-overlay")
  .attr("class", "handle")
  .attr("r", 9)
  .attr("cx", sliderScale(initialQMean));

// --- Axes and curves ------------------------------------------------------

svg.append("g")
  .attr("class", "x axis")
  .attr("transform", `translate(0,${height})`)
  .call(xAxis);

svg.append("g")
  .attr("class", "y axis")
  .call(yAxis);

const pPath = svg.append("path")
  .datum(data)
  .attr("class", "line")
  .attr("d", pLine);

const qPath = svg.append("path")
  .datum(data)
  .attr("class", "line")
  .style("stroke", "red")
  .attr("d", qLine);

// --- KL divergence read-outs ----------------------------------------------

const klPqLabel = svg.append("text")
  .attr("x", 0.8 * width)
  .attr("y", 2)
  .attr("text-anchor", "start")
  .attr("class", "myLabel") // easy to style with CSS
  .attr("font-size", "20px");

const klQpLabel = svg.append("text")
  .attr("x", 0.8 * width)
  .attr("y", 25)
  .attr("text-anchor", "start")
  .attr("class", "myLabel") // easy to style with CSS
  .attr("font-size", "20px");

updateKlLabels();

/**
 * Discrete KL divergence KL(from || to) over the sampled points.
 *
 * Each column is first normalised to sum to 1, which makes the result
 * independent of the grid spacing and of the overall scale of either curve.
 *
 * @param {Array<Object>} points - Samples carrying the two density columns.
 * @param {string} from - Key of the reference distribution (e.g. "p").
 * @param {string} to - Key of the approximating distribution (e.g. "q").
 * @returns {number} The divergence in nats; NaN when a column has no mass.
 */
function klDivergence(points, from, to) {
  const fromTotal = points.reduce((sum, d) => sum + d[from], 0);
  const toTotal = points.reduce((sum, d) => sum + d[to], 0);
  if (!(fromTotal > 0) || !(toTotal > 0)) {
    return NaN;
  }
  return points.reduce((kl, d) => {
    const fromMass = d[from] / fromTotal;
    if (fromMass === 0) {
      return kl; // 0 * log(0 / x) is taken as 0
    }
    return kl + fromMass * Math.log(fromMass / (d[to] / toTotal));
  }, 0);
}

/** Writes both divergences, rounded so the text stays inside the chart. */
function updateKlLabels() {
  klPqLabel.text("KL(p,q) = " + klDivergence(data, "p", "q").toFixed(3));
  klQpLabel.text("KL(q,p) = " + klDivergence(data, "q", "p").toFixed(3));
}

/**
 * Moves q to a new mean and refreshes everything that depends on the samples.
 *
 * @param {number} qMean - New mean for q, in x-axis units.
 */
function setQMean(qMean) {
  handle.attr("cx", sliderScale(qMean));

  // Start from an empty array so a regeneration never leaves stale points.
  data.length = 0;
  getData(qMean);

  pPath.datum(data).attr("d", pLine);
  qPath.datum(data).attr("d", qLine);
  updateKlLabels();
}
// Adapted from Jason Davies' science library:
// https://github.com/jasondavies/science.js/
/**
 * Probability density of a normal distribution.
 *
 * @param {number} x - Point at which to evaluate the density.
 * @param {number} mean - Mean of the distribution.
 * @param {number} std - Standard deviation of the distribution.
 * @returns {number} Density value at `x`.
 */
function gaussian(x, mean, std) {
  // Standardise first, then scale the unit-normal density by 1/std.
  const z = (x - mean) / std;
  const unitNormalPeak = 1 / Math.sqrt(2 * Math.PI);
  const unitDensity = unitNormalPeak * Math.exp(-0.5 * z * z);
  return unitDensity / std;
}
</script>
</html>
https://d3js.org/d3.v4.min.js