-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathglobe.py
More file actions
103 lines (84 loc) · 3.05 KB
/
globe.py
File metadata and controls
103 lines (84 loc) · 3.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import pandas as pd
import plotly.express as px
from dash import Dash, dcc, html, Input, Output
import json
import numpy as np
# Read the export dataset. The app reads its 'Commodity', 'country' and
# 'value' columns.
df = pd.read_csv('archive/2010_2021_HS2_export.csv')

# Load the country-boundary GeoJSON. JSON text is UTF-8 by specification
# (RFC 8259), so decode it explicitly rather than relying on the platform's
# locale encoding.
with open('countries.geo.json', 'r', encoding='utf-8') as f:
    geojson_data = json.load(f)

# Create the Dash app.
app = Dash(__name__)

# Distinct commodities (in order of first appearance) offered in the dropdown.
commodities = df['Commodity'].unique()

# Page layout: a heading, the commodity selector, and the map it drives.
app.layout = html.Div([
    html.H1('Commodity Trade Analysis by Country'),
    dcc.Dropdown(
        id='commodity-dropdown',
        options=[{'label': commodity, 'value': commodity} for commodity in commodities],
        # Pre-select the first commodity; no selection if the dataset is empty
        # (indexing an empty array would raise IndexError at import time).
        value=commodities[0] if len(commodities) else None,
        clearable=False,
        placeholder='Select a commodity'
    ),
    dcc.Graph(id='trade-heatmap'),
])
def _stepped_colorscale(band_colors):
    """Return a Plotly colorscale made of equal-width, solid-colour bands.

    With ``n = len(band_colors)``, band ``k`` covers the fraction
    ``[k / n, (k + 1) / n]`` of the colour axis and is painted
    ``band_colors[k]``. Every interior boundary is listed twice (once for each
    neighbouring colour) so the scale steps between colours instead of
    interpolating across the band.

    Args:
        band_colors: Plotly colour strings, lowest band first.

    Returns:
        A list of ``(position, colour)`` pairs with positions in ``[0, 1]``.
    """
    n_bands = len(band_colors)
    scale = []
    for k, color in enumerate(band_colors):
        scale.append((k / n_bands, color))
        scale.append(((k + 1) / n_bands, color))
    return scale


# Define callback to update heatmap
@app.callback(
    Output('trade-heatmap', 'figure'),
    [Input('commodity-dropdown', 'value')]
)
def update_heatmap(selected_commodity):
    """Build the globe choropleth for the commodity chosen in the dropdown.

    Sums 'value' per 'country' over the rows whose 'Commodity' equals the
    selection, scales each total by the largest one, and colours countries
    with a stepped Oranges scale on an orthographic projection.

    Args:
        selected_commodity: Current value of the 'commodity-dropdown'.

    Returns:
        The plotly figure rendered into the 'trade-heatmap' graph.
    """
    # Keep only the rows for the selected commodity, then total them per country.
    commodity_data = df[df['Commodity'] == selected_commodity]
    country_total_trade = commodity_data.groupby('country')['value'].sum().reset_index()

    # Scale totals relative to the largest country. When there is nothing
    # positive to scale by (no rows -> NaN max, or every total is zero), use 0
    # everywhere instead of dividing by zero and producing NaN colour values.
    max_value = country_total_trade['value'].max()
    if max_value > 0:
        country_total_trade['normalized_value'] = country_total_trade['value'] / max_value
    else:
        country_total_trade['normalized_value'] = 0.0

    # len(colors) - 1 equal-width bands, painted with every shade except the
    # lightest one.
    colors = px.colors.sequential.Oranges
    n_bands = len(colors) - 1
    cc_scale = _stepped_colorscale(colors[1:])

    # Colorscale positions are fractions of the colour-axis range, so pin that
    # range to the data; band k then spans vmin + [k, k+1] * (vmax - vmin) / n_bands.
    # With a degenerate range (one distinct value, or no data) let Plotly choose.
    vmin = country_total_trade['normalized_value'].min()
    vmax = country_total_trade['normalized_value'].max()
    has_range = vmin < vmax  # False for NaN (empty data) as well
    range_color = (float(vmin), float(vmax)) if has_range else None
    # One colorbar tick on each interior band boundary.
    ticks = np.linspace(vmin, vmax, n_bands + 1)[1:-1] if has_range else np.array([])

    # Choropleth of the normalised totals on a geo layout.
    fig = px.choropleth(
        country_total_trade,
        locations='country',
        # NOTE(review): per the Plotly reference, 'country names' matches
        # `locations` against Plotly's built-in map rather than the features of
        # the `geojson` passed below — confirm whether the geojson is needed.
        locationmode='country names',
        color='normalized_value',  # share of the largest country's total
        hover_name='country',
        hover_data={'value': ':.2f'},  # raw summed value, two decimals, on hover
        title=f'Total Trade Analysis for {selected_commodity}',
        geojson=geojson_data,
        color_continuous_scale=cc_scale,  # stepped Oranges scale
        range_color=range_color,
    )
    fig.update_layout(
        coloraxis={
            'colorbar': {
                'tickmode': 'array',
                'tickvals': ticks,
                # Label each tick with the normalised value it marks.
                'ticktext': np.round(ticks, 3),
            }
        },
        height=500,
        plot_bgcolor='rgba(0,0,0,0)',
    )
    fig.update_geos(
        fitbounds='locations',
        showcountries=True,
        projection_type='orthographic',  # render the map as a globe
        showocean=True,
        oceancolor='LightBlue',
    )
    return fig
if __name__ == '__main__':
    # Start the development server. `app.run` is the supported entry point;
    # the older `app.run_server` alias was removed in Dash 3.
    app.run(debug=True)