Nipun commited on
Commit
14a876d
·
1 Parent(s): b5978b3
Files changed (2) hide show
  1. app.py +65 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import numpy as np
4
+ import plotly.graph_objects as go
5
+
6
+ st.set_page_config(layout="wide")
7
+
8
+ st.title("Bivariate Normal Distribution")
9
+
10
+ # Sidebar for controls
11
+ with st.sidebar:
12
+ st.header("Controls")
13
+ mu_x = st.slider("Mean of X (μx)", -2.0, 2.0, 0.0)
14
+ mu_y = st.slider("Mean of Y (μy)", -2.0, 2.0, 0.0)
15
+ sigma_x = st.slider("Std Dev of X (σx)", 0.1, 2.0, 1.0)
16
+ sigma_y = st.slider("Std Dev of Y (σy)", 0.1, 2.0, 1.0)
17
+ rho = st.slider("Correlation (ρ)", -0.9, 0.9, 0.0)
18
+
19
+ # Covariance matrix
20
+ cov_matrix = torch.tensor([[sigma_x**2, rho * sigma_x * sigma_y],
21
+ [rho * sigma_x * sigma_y, sigma_y**2]])
22
+ mean_vector = torch.tensor([mu_x, mu_y])
23
+
24
+ # Create distribution
25
+ distribution = torch.distributions.MultivariateNormal(mean_vector, cov_matrix)
26
+
27
+ # Generate grid
28
+ x = torch.linspace(-4, 4, 100)
29
+ y = torch.linspace(-4, 4, 100)
30
+ X, Y = torch.meshgrid(x, y, indexing='xy')
31
+ pos = torch.stack((X, Y), dim=-1)
32
+
33
+ Z = torch.exp(distribution.log_prob(pos))
34
+
35
+ # Compute marginal distributions
36
+ marginal_x = torch.distributions.Normal(mean_vector[0], torch.sqrt(cov_matrix[0, 0]))
37
+ marginal_y = torch.distributions.Normal(mean_vector[1], torch.sqrt(cov_matrix[1, 1]))
38
+
39
+ pdf_x = torch.exp(marginal_x.log_prob(x))
40
+ pdf_y = torch.exp(marginal_y.log_prob(y))
41
+
42
+ # Convert to numpy for plotting
43
+ X, Y, Z = X.numpy(), Y.numpy(), Z.numpy()
44
+
45
+ # Create 3D surface plot
46
+ fig = go.Figure()
47
+ fig.add_trace(go.Surface(z=Z, x=X, y=Y, colorscale='Viridis', opacity=0.9, name='Density'))
48
+
49
+ # Marginal distributions on the walls
50
+ fig.add_trace(go.Scatter3d(x=x.numpy(), y=np.full_like(x.numpy(), -4), z=pdf_x.numpy() / np.max(pdf_x.numpy()) * np.max(Z), mode='lines', line=dict(color='red', width=4), name='Marginal X'))
51
+ fig.add_trace(go.Scatter3d(x=np.full_like(y.numpy(), 4), y=y.numpy(), z=pdf_y.numpy() / np.max(pdf_y.numpy()) * np.max(Z), mode='lines', line=dict(color='blue', width=4), name='Marginal Y'))
52
+
53
+ fig.update_layout(
54
+ scene=dict(
55
+ xaxis_title='X',
56
+ yaxis_title='Y',
57
+ zaxis_title='Density',
58
+ ),
59
+ margin=dict(l=0, r=0, t=20, b=20),
60
+ legend=dict(x=0.8, y=0.9, font=dict(size=14)),
61
+ width=1100, height=800
62
+ )
63
+
64
+ # Main display
65
+ st.plotly_chart(fig, use_container_width=True)
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ streamlit
2
+ matplotlib
3
+ torch
4
+ plotly