Skip to content

Commit e586754

Browse files
committed
More on residuals and errors!
1 parent 3957928 commit e586754

File tree

6 files changed

+434
-39
lines changed

6 files changed

+434
-39
lines changed

extras/animations/02-machine_learning_fundamentals.py

Lines changed: 335 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def construct(self):
7878
)
7979

8080
best_params = MathTex(
81-
f"\\theta_0={theta_0_best:.2f}, \\theta_1={theta_1_best:.2f}",
81+
f"\\theta_0^*={theta_0_best:.2f}, \\theta_1^*={theta_1_best:.2f}",
8282
font_size=36,
8383
color=GREEN
8484
)
@@ -96,6 +96,339 @@ def construct(self):
9696
self.wait(1)
9797

9898

99+
class QuadraticRegressionOverUnderfit(Scene):
    """Animate under-fitting, a good fit and over-fitting on noisy quadratic data.

    Eleven samples are drawn from ``y = 1 + 0*x + 0.5*x**2`` plus Gaussian
    noise (standard deviation 0.6, NumPy global seed 42).  Least-squares
    polynomials of degree 1, 2 and 10 are fitted and shown in turn, then
    compared side by side, and finally the mean squared residual of each
    model is displayed next to the mean squared noise.
    """

    @staticmethod
    def _polynomial(theta):
        """Return a callable evaluating ``sum(theta[k] * x**k)``.

        The returned function accepts scalars (for ``axes.plot``) as well as
        NumPy arrays (for evaluating predictions on all samples at once).
        """
        def evaluate(x):
            return sum(coef * x**power for power, coef in enumerate(theta))

        return evaluate

    @staticmethod
    def _fit_polynomial(x, y, degree):
        """Return least-squares coefficients (lowest power first) of the given degree."""
        design = np.vstack([x**power for power in range(degree + 1)]).T
        return np.linalg.lstsq(design, y, rcond=None)[0]

    @staticmethod
    def _side_panel(heading, color, definition_tex, equation_tex,
                    math_font_size=32, lift=1.5):
        """Build the heading / function-class / equation stack on the left edge.

        Returns the three mobjects ``(label, definition, equation)``; each one
        is placed directly below the previous one, left-aligned.
        """
        label = Text(heading, font_size=28, color=color)
        label.to_edge(LEFT).shift(UP * lift)

        definition = MathTex(definition_tex, font_size=math_font_size)
        definition.next_to(label, DOWN, aligned_edge=LEFT)

        equation = MathTex(equation_tex, font_size=math_font_size)
        equation.next_to(definition, DOWN, aligned_edge=LEFT)

        return label, definition, equation

    @staticmethod
    def _legend_entry(caption, color, **line_style):
        """Return a short coloured line sample followed by its caption."""
        return VGroup(
            Line(ORIGIN, RIGHT * 0.5, color=color, **line_style),
            Text(caption, font_size=20, color=color),
        ).arrange(RIGHT, buff=0.2)

    def construct(self):
        """Build and play the whole scene (see the class docstring for the storyline)."""
        # Create axes
        axes = Axes(
            x_range=[-3, 3, 1],
            y_range=[-2, 8, 2],
            x_length=7,
            y_length=5,
            axis_config={"color": WHITE},
            tips=False
        )

        # Add labels
        x_label = axes.get_x_axis_label("x", edge=RIGHT, direction=RIGHT)
        y_label = axes.get_y_axis_label("y", edge=UP, direction=UP)

        # Shift the plot right and down so the left edge is free for text.
        axes_group = VGroup(axes, x_label, y_label)
        axes_group.shift(RIGHT * 1.5)
        axes_group.shift(DOWN * 1)

        self.play(Create(axes), Write(x_label), Write(y_label))
        self.wait(1)

        # Generate sample data points (quadratic with noise).
        # The global seed keeps the rendered noise identical between runs.
        np.random.seed(42)
        x_data = np.array([-2.5, -2, -1.5, -1, -0.5, 0, 0.5, 1, 1.5, 2, 2.5])
        theta_true = np.array([1, 0, 0.5])
        true_func = self._polynomial(theta_true)
        y_true = true_func(x_data)  # True quadratic function
        y_data = y_true + np.random.normal(0, 0.6, len(x_data))

        # Draw the real distribution
        real_distribution = axes.plot(true_func, color=WHITE, x_range=(-3, 3))

        # 0. REAL DISTRIBUTION
        real_label, definition_real, equation_real = self._side_panel(
            "Real Distribution",
            WHITE,
            r"f_{\boldsymbol{\theta}} \in \mathcal{F}_2",
            r"y = \theta_0 + \theta_1 x + \theta_2 x^2 + \epsilon",
            math_font_size=28,
        )

        self.play(Create(real_distribution), Write(real_label))
        self.play(Write(definition_real), Write(equation_real))

        # Create dots for data points
        dots = VGroup(*[
            Dot(axes.c2p(x, y), color=WHITE, radius=0.08)
            for x, y in zip(x_data, y_data)
        ])

        # Animate data points appearing
        self.play(LaggedStart(*[GrowFromCenter(dot) for dot in dots], lag_ratio=0.2))
        self.wait(1)

        # Draw errors from noise: one vertical segment per sample, from the
        # true value to the observed value.
        error_lines = VGroup(*[
            Line(axes.c2p(x, y_t), axes.c2p(x, y), color=YELLOW, stroke_width=2)
            for x, y, y_t in zip(x_data, y_data, y_true)
        ])

        # Add label for errors
        error_label = MathTex(r"\epsilon \sim \mathcal{N}(0, \sigma^2)", font_size=28, color=YELLOW)
        error_label.next_to(equation_real, DOWN, aligned_edge=LEFT)

        self.play(LaggedStart(*[Create(line) for line in error_lines], lag_ratio=0.1))
        self.play(Write(error_label))
        self.wait(2)

        # Fade out error visualization
        self.play(FadeOut(error_lines), FadeOut(error_label))
        self.wait(0.5)

        # Title
        title = Text("Underfitting vs Good Fit vs Overfitting", font_size=32)
        title.to_edge(UP)
        self.play(Write(title))
        self.wait(0.5)

        self.play(FadeOut(real_distribution), FadeOut(real_label), FadeOut(definition_real), FadeOut(equation_real))

        # 1. UNDERFITTING - Linear model (too simple)
        underfit_label, definition_underfit, equation_underfit = self._side_panel(
            "Underfitting (Too Simple)",
            RED,
            r"f_{\boldsymbol{\theta}} \in \mathcal{F}_1",
            r"\hat{y} = \theta_0 + \theta_1 x",
        )

        self.play(Write(underfit_label), Write(definition_underfit), Write(equation_underfit))

        # Fit a linear model (underfitting)
        theta_linear = self._fit_polynomial(x_data, y_data, 1)
        linear_func = self._polynomial(theta_linear)

        underfit_line = axes.plot(linear_func, color=RED, x_range=[-3, 3])

        params_underfit = MathTex(
            f"\\theta_0^*={theta_linear[0]:.2f}, \\theta_1^*={theta_linear[1]:.2f}",
            font_size=28,
            color=RED
        )
        params_underfit.next_to(equation_underfit, DOWN, aligned_edge=LEFT)

        self.play(Create(underfit_line), Write(params_underfit))
        self.wait(2)

        # Fade out underfitting
        self.play(
            FadeOut(underfit_line),
            FadeOut(underfit_label),
            FadeOut(definition_underfit),
            FadeOut(equation_underfit),
            FadeOut(params_underfit)
        )
        self.wait(0.5)

        # 2. GOOD FIT - Quadratic model (just right)
        goodfit_label, definition_goodfit, equation_goodfit = self._side_panel(
            "Good Fit (Just Right)",
            GREEN,
            r"f_{\boldsymbol{\theta}} \in \mathcal{F}_2",
            r"\hat{y} = \theta_0 + \theta_1 x + \theta_2 x^2",
        )

        self.play(Write(goodfit_label), Write(definition_goodfit), Write(equation_goodfit))

        # Fit a quadratic model (good fit)
        theta_quad = self._fit_polynomial(x_data, y_data, 2)
        quad_func = self._polynomial(theta_quad)

        goodfit_curve = axes.plot(quad_func, color=GREEN, x_range=[-3, 3])

        params_goodfit = MathTex(
            f"\\theta_0^*={theta_quad[0]:.2f}, \\theta_1^*={theta_quad[1]:.2f}, \\theta_2^*={theta_quad[2]:.2f}",
            font_size=28,
            color=GREEN
        )
        params_goodfit.next_to(equation_goodfit, DOWN, aligned_edge=LEFT)

        self.play(Create(goodfit_curve), Write(params_goodfit))
        self.wait(2)

        # Fade out good fit
        self.play(
            FadeOut(goodfit_curve),
            FadeOut(goodfit_label),
            FadeOut(definition_goodfit),
            FadeOut(equation_goodfit),
            FadeOut(params_goodfit)
        )
        self.wait(0.5)

        # 3. OVERFITTING - High degree polynomial (too complex)
        overfit_label, definition_overfit, equation_overfit = self._side_panel(
            "Overfitting (Too Complex)",
            ORANGE,
            r"f_{\boldsymbol{\theta}} \in \mathcal{F}_{10}",
            r"\hat{y} = \theta_0 + \theta_1 x + \cdots + \theta_{10} x^{10}",
            lift=2,
        )

        self.play(Write(overfit_label), Write(definition_overfit), Write(equation_overfit))

        # Fit a 10th degree polynomial (overfitting): 11 coefficients for
        # 11 samples.
        degree = 10
        theta_poly = self._fit_polynomial(x_data, y_data, degree)
        poly_func = self._polynomial(theta_poly)

        # Plotted only over the sampled interval [-2.5, 2.5], unlike the
        # other curves which span [-3, 3].
        overfit_curve = axes.plot(
            poly_func,
            color=ORANGE,
            x_range=[-2.5, 2.5],
            use_smoothing=True
        )

        self.play(Create(overfit_curve))
        self.wait(2)

        # Show all three together for comparison.  The overfit curve is still
        # on screen, so only the other curves have to be recreated.
        comparison_label = Text("Comparison", font_size=32, color=YELLOW)
        comparison_label.next_to(title, DOWN)

        self.play(
            FadeOut(overfit_label),
            FadeOut(definition_overfit),
            FadeOut(equation_overfit),
            Write(comparison_label)
        )

        # Recreate the linear and quadratic curves
        underfit_line_final = axes.plot(
            linear_func,
            color=RED,
            x_range=[-3, 3],
            stroke_width=3
        )

        goodfit_curve_final = axes.plot(
            quad_func,
            color=GREEN,
            x_range=[-3, 3],
            stroke_width=3
        )

        # Add the real distribution back
        real_distribution_final = axes.plot(
            true_func,
            color=WHITE,
            x_range=[-3, 3],
            stroke_width=4,
            stroke_opacity=0.7
        )

        self.play(
            Create(underfit_line_final),
            Create(goodfit_curve_final),
            Create(real_distribution_final)
        )

        # Add legend
        legend = VGroup(
            self._legend_entry("Real Distribution", WHITE, stroke_width=4, stroke_opacity=0.7),
            self._legend_entry("Underfit", RED, stroke_width=3),
            self._legend_entry("Good Fit", GREEN, stroke_width=3),
            self._legend_entry("Overfit", ORANGE, stroke_width=3),
        ).arrange(DOWN, aligned_edge=LEFT, buff=0.2)
        legend.to_edge(LEFT).shift(DOWN * 1.5)

        self.play(FadeIn(legend))
        self.wait(3)

        # Error Analysis Section
        self.play(
            FadeOut(comparison_label),
            FadeOut(legend)
        )

        error_title = Text("Error Analysis: Residuals vs Noise", font_size=32, color=YELLOW)
        error_title.next_to(title, DOWN)
        self.play(Write(error_title))
        self.wait(1)

        # Model predictions at the sample locations
        y_pred_linear = linear_func(x_data)
        y_pred_quad = quad_func(x_data)
        y_pred_poly = poly_func(x_data)

        # Calculate MSE for each model and noise
        mse_noise = np.mean((y_data - y_true)**2)
        mse_underfit = np.mean((y_data - y_pred_linear)**2)
        mse_goodfit = np.mean((y_data - y_pred_quad)**2)
        mse_overfit = np.mean((y_data - y_pred_poly)**2)

        # Create error comparison table.  The values are *mean* squared
        # errors, so the labels carry the 1/n factor; a bare squared norm
        # would denote the sum of squares instead.
        error_labels = VGroup(
            MathTex(r"\frac{1}{n}\lVert\epsilon\rVert_2^2", font_size=28, color=YELLOW),
            MathTex(r"\frac{1}{n}\lVert r_{\text{underfit}}\rVert_2^2", font_size=28, color=RED),
            MathTex(r"\frac{1}{n}\lVert r_{\text{good fit}}\rVert_2^2", font_size=28, color=GREEN),
            MathTex(r"\frac{1}{n}\lVert r_{\text{overfit}}\rVert_2^2", font_size=28, color=ORANGE)
        ).arrange(DOWN, aligned_edge=LEFT, buff=0.3)

        error_values = VGroup(
            MathTex(f" = {mse_noise:.3f}", font_size=28, color=YELLOW),
            MathTex(f" = {mse_underfit:.3f}", font_size=28, color=RED),
            MathTex(f" = {mse_goodfit:.3f}", font_size=28, color=GREEN),
            MathTex(f" = {mse_overfit:.3f}", font_size=28, color=ORANGE)
        ).arrange(DOWN, aligned_edge=LEFT, buff=0.3)

        # Align the values with their corresponding labels
        for label, value in zip(error_labels, error_values):
            value.align_to(label, UP)

        # Position labels and values
        error_labels.to_edge(LEFT).shift(UP * 0.5)
        error_values.next_to(error_labels, RIGHT, buff=0.5)

        self.play(
            Write(error_labels),
            Write(error_values)
        )
        self.wait(4)

        # Fade out everything
        self.play(*[FadeOut(mob) for mob in self.mobjects])
        self.wait(1)
430+
431+
99432
class BinaryClassificationSimple(Scene):
100433
def construct(self):
101434

@@ -230,7 +563,7 @@ def construct(self):
230563
)
231564

232565
best_params = MathTex(
233-
f"\\theta_0={theta_0_best:.1f}, \\theta_1={theta_1_best:.1f}, \\theta_2={theta_2_best:.1f}",
566+
f"\\theta_0^*={theta_0_best:.1f}, \\theta_1^*={theta_1_best:.1f}, \\theta_2^*={theta_2_best:.1f}",
234567
font_size=32,
235568
color=GREEN
236569
)

0 commit comments

Comments
 (0)