@@ -78,7 +78,7 @@ def construct(self):
7878 )
7979
8080 best_params = MathTex (
81- f"\\ theta_0={ theta_0_best :.2f} , \\ theta_1={ theta_1_best :.2f} " ,
81+ f"\\ theta_0^* ={ theta_0_best :.2f} , \\ theta_1^* ={ theta_1_best :.2f} " ,
8282 font_size = 36 ,
8383 color = GREEN
8484 )
@@ -96,6 +96,339 @@ def construct(self):
9696 self .wait (1 )
9797
9898
99+ class QuadraticRegressionOverUnderfit (Scene ):
100+ def construct (self ):
101+ # Create axes
102+ axes = Axes (
103+ x_range = [- 3 , 3 , 1 ],
104+ y_range = [- 2 , 8 , 2 ],
105+ x_length = 7 ,
106+ y_length = 5 ,
107+ axis_config = {"color" : WHITE },
108+ tips = False
109+ )
110+
111+ # Add labels
112+ x_label = axes .get_x_axis_label ("x" , edge = RIGHT , direction = RIGHT )
113+ y_label = axes .get_y_axis_label ("y" , edge = UP , direction = UP )
114+
115+ axes_group = VGroup (axes , x_label , y_label )
116+ axes_group .shift (RIGHT * 1.5 )
117+ axes_group .shift (DOWN * 1 )
118+
119+
120+ self .play (Create (axes ), Write (x_label ), Write (y_label ))
121+ self .wait (1 )
122+
123+ # Generate sample data points (quadratic with noise)
124+ np .random .seed (42 )
125+ x_data = np .array ([- 2.5 , - 2 , - 1.5 , - 1 , - 0.5 , 0 , 0.5 , 1 , 1.5 , 2 , 2.5 ])
126+ theta_true = np .array ([1 , 0 , 0.5 ])
127+ y_true = theta_true [0 ] + theta_true [1 ] * x_data + theta_true [2 ] * x_data ** 2 # True quadratic function
128+ y_data = y_true + np .random .normal (0 , 0.6 , len (x_data ))
129+
130+ # Draw the real distribution
131+ real_distribution = axes .plot (
132+ lambda x : theta_true [0 ] + theta_true [1 ] * x + theta_true [2 ] * x ** 2 ,
133+ color = WHITE ,
134+ x_range = (- 3 , 3 )
135+ )
136+
137+ # 0. REAL DISTRIBUTION
138+ real_label = Text ("Real Distribution" , font_size = 28 , color = WHITE )
139+ real_label .to_edge (LEFT ).shift (UP * 1.5 )
140+
141+ definition_real = MathTex (r"f_{\boldsymbol{\theta}} \in \mathcal{F}_2" , font_size = 28 )
142+ definition_real .next_to (real_label , DOWN , aligned_edge = LEFT )
143+
144+ equation_real = MathTex (r"y = \theta_0 + \theta_1 x + \theta_2 x^2 + \epsilon" , font_size = 28 )
145+ equation_real .next_to (definition_real , DOWN , aligned_edge = LEFT )
146+
147+ self .play (Create (real_distribution ), Write (real_label ))
148+ self .play (Write (definition_real ), Write (equation_real ))
149+
150+ # Create dots for data points
151+ dots = VGroup ()
152+ for x , y in zip (x_data , y_data ):
153+ dot = Dot (axes .c2p (x , y ), color = WHITE , radius = 0.08 )
154+ dots .add (dot )
155+
156+ # Animate data points appearing
157+ self .play (LaggedStart (* [GrowFromCenter (dot ) for dot in dots ], lag_ratio = 0.2 ))
158+ self .wait (1 )
159+
160+ # Draw Errors from Noise
161+ error_lines = VGroup ()
162+ for x , y , y_t in zip (x_data , y_data , y_true ):
163+ # Draw vertical line from true value to observed value
164+ start_point = axes .c2p (x , y_t )
165+ end_point = axes .c2p (x , y )
166+ error_line = Line (start_point , end_point , color = YELLOW , stroke_width = 2 )
167+ error_lines .add (error_line )
168+
169+ # Add label for errors
170+ error_label = MathTex (r"\epsilon \sim \mathcal{N}(0, \sigma^2)" , font_size = 28 , color = YELLOW )
171+ error_label .next_to (equation_real , DOWN , aligned_edge = LEFT )
172+
173+ self .play (LaggedStart (* [Create (line ) for line in error_lines ], lag_ratio = 0.1 ))
174+ self .play (Write (error_label ))
175+ self .wait (2 )
176+
177+ # Fade out error visualization
178+ self .play (FadeOut (error_lines ), FadeOut (error_label ))
179+ self .wait (0.5 )
180+
181+ # Title
182+ title = Text ("Underfitting vs Good Fit vs Overfitting" , font_size = 32 )
183+ title .to_edge (UP )
184+ self .play (Write (title ))
185+ self .wait (0.5 )
186+
187+ self .play (FadeOut (real_distribution ), FadeOut (real_label ), FadeOut (definition_real ), FadeOut (equation_real ))
188+
189+ # 1. UNDERFITTING - Linear model (too simple)
190+ underfit_label = Text ("Underfitting (Too Simple)" , font_size = 28 , color = RED )
191+ underfit_label .to_edge (LEFT ).shift (UP * 1.5 )
192+
193+ definition_underfit = MathTex (r"f_{\boldsymbol{\theta}} \in \mathcal{F}_1" , font_size = 32 )
194+ definition_underfit .next_to (underfit_label , DOWN , aligned_edge = LEFT )
195+
196+ equation_underfit = MathTex (r"\hat{y} = \theta_0 + \theta_1 x" , font_size = 32 )
197+ equation_underfit .next_to (definition_underfit , DOWN , aligned_edge = LEFT )
198+
199+ self .play (Write (underfit_label ), Write (definition_underfit ), Write (equation_underfit ))
200+
201+ # Fit a linear model (underfitting)
202+ A_linear = np .vstack ([np .ones (len (x_data )), x_data ]).T
203+ theta_linear = np .linalg .lstsq (A_linear , y_data , rcond = None )[0 ]
204+
205+ underfit_line = axes .plot (
206+ lambda x : theta_linear [0 ] + theta_linear [1 ] * x ,
207+ color = RED ,
208+ x_range = [- 3 , 3 ]
209+ )
210+
211+ params_underfit = MathTex (
212+ f"\\ theta_0^*={ theta_linear [0 ]:.2f} , \\ theta_1^*={ theta_linear [1 ]:.2f} " ,
213+ font_size = 28 ,
214+ color = RED
215+ )
216+ params_underfit .next_to (equation_underfit , DOWN , aligned_edge = LEFT )
217+
218+ self .play (Create (underfit_line ), Write (params_underfit ))
219+ self .wait (2 )
220+
221+ # Fade out underfitting
222+ self .play (
223+ FadeOut (underfit_line ),
224+ FadeOut (underfit_label ),
225+ FadeOut (definition_underfit ),
226+ FadeOut (equation_underfit ),
227+ FadeOut (params_underfit )
228+ )
229+ self .wait (0.5 )
230+
231+ # 2. GOOD FIT - Quadratic model (just right)
232+ goodfit_label = Text ("Good Fit (Just Right)" , font_size = 28 , color = GREEN )
233+ goodfit_label .to_edge (LEFT ).shift (UP * 1.5 )
234+
235+ definition_goodfit = MathTex (r"f_{\boldsymbol{\theta}} \in \mathcal{F}_2" , font_size = 32 )
236+ definition_goodfit .next_to (goodfit_label , DOWN , aligned_edge = LEFT )
237+
238+ equation_goodfit = MathTex (r"\hat{y} = \theta_0 + \theta_1 x + \theta_2 x^2" , font_size = 32 )
239+ equation_goodfit .next_to (definition_goodfit , DOWN , aligned_edge = LEFT )
240+
241+ self .play (Write (goodfit_label ), Write (definition_goodfit ), Write (equation_goodfit ))
242+
243+ # Fit a quadratic model (good fit)
244+ A_quad = np .vstack ([np .ones (len (x_data )), x_data , x_data ** 2 ]).T
245+ theta_quad = np .linalg .lstsq (A_quad , y_data , rcond = None )[0 ]
246+
247+ goodfit_curve = axes .plot (
248+ lambda x : theta_quad [0 ] + theta_quad [1 ] * x + theta_quad [2 ] * x ** 2 ,
249+ color = GREEN ,
250+ x_range = [- 3 , 3 ]
251+ )
252+
253+ params_goodfit = MathTex (
254+ f"\\ theta_0^*={ theta_quad [0 ]:.2f} , \\ theta_1^*={ theta_quad [1 ]:.2f} , \\ theta_2^*={ theta_quad [2 ]:.2f} " ,
255+ font_size = 28 ,
256+ color = GREEN
257+ )
258+ params_goodfit .next_to (equation_goodfit , DOWN , aligned_edge = LEFT )
259+
260+ self .play (Create (goodfit_curve ), Write (params_goodfit ))
261+ self .wait (2 )
262+
263+ # Fade out good fit
264+ self .play (
265+ FadeOut (goodfit_curve ),
266+ FadeOut (goodfit_label ),
267+ FadeOut (definition_goodfit ),
268+ FadeOut (equation_goodfit ),
269+ FadeOut (params_goodfit )
270+ )
271+ self .wait (0.5 )
272+
273+ # 3. OVERFITTING - High degree polynomial (too complex)
274+ overfit_label = Text ("Overfitting (Too Complex)" , font_size = 28 , color = ORANGE )
275+ overfit_label .to_edge (LEFT ).shift (UP * 2 )
276+
277+ definition_overfit = MathTex (r"f_{\boldsymbol{\theta}} \in \mathcal{F}_{10}" , font_size = 32 )
278+ definition_overfit .next_to (overfit_label , DOWN , aligned_edge = LEFT )
279+
280+ equation_overfit = MathTex (
281+ r"\hat{y} = \theta_0 + \theta_1 x + \cdots + \theta_{10} x^{10}" ,
282+ font_size = 32
283+ )
284+ equation_overfit .next_to (definition_overfit , DOWN , aligned_edge = LEFT )
285+
286+ self .play (Write (overfit_label ), Write (definition_overfit ), Write (equation_overfit ))
287+
288+ # Fit a 10th degree polynomial (overfitting)
289+ degree = 10
290+ A_poly = np .vstack ([x_data ** i for i in range (degree + 1 )]).T
291+ theta_poly = np .linalg .lstsq (A_poly , y_data , rcond = None )[0 ]
292+
293+ def poly_func (x ):
294+ return sum (theta_poly [i ] * x ** i for i in range (degree + 1 ))
295+
296+ overfit_curve = axes .plot (
297+ poly_func ,
298+ color = ORANGE ,
299+ x_range = [- 2.5 , 2.5 ],
300+ use_smoothing = True
301+ )
302+
303+ self .play (Create (overfit_curve ))
304+ self .wait (2 )
305+
306+ # Show all three together for comparison
307+ comparison_label = Text ("Comparison" , font_size = 32 , color = YELLOW )
308+ comparison_label .next_to (title , DOWN )
309+
310+ self .play (
311+ FadeOut (overfit_label ),
312+ FadeOut (definition_overfit ),
313+ FadeOut (equation_overfit ),
314+ Write (comparison_label )
315+ )
316+
317+ # Recreate all three curves
318+ underfit_line_final = axes .plot (
319+ lambda x : theta_linear [0 ] + theta_linear [1 ] * x ,
320+ color = RED ,
321+ x_range = [- 3 , 3 ],
322+ stroke_width = 3
323+ )
324+
325+ goodfit_curve_final = axes .plot (
326+ lambda x : theta_quad [0 ] + theta_quad [1 ] * x + theta_quad [2 ] * x ** 2 ,
327+ color = GREEN ,
328+ x_range = [- 3 , 3 ],
329+ stroke_width = 3
330+ )
331+
332+ # Add the real distribution back
333+ real_distribution_final = axes .plot (
334+ lambda x : theta_true [0 ] + theta_true [1 ] * x + theta_true [2 ] * x ** 2 ,
335+ color = WHITE ,
336+ x_range = [- 3 , 3 ],
337+ stroke_width = 4 ,
338+ stroke_opacity = 0.7
339+ )
340+
341+ self .play (
342+ Create (underfit_line_final ),
343+ Create (goodfit_curve_final ),
344+ Create (real_distribution_final )
345+ )
346+
347+ # Add legend
348+ legend_under = VGroup (
349+ Line (ORIGIN , RIGHT * 0.5 , color = RED , stroke_width = 3 ),
350+ Text ("Underfit" , font_size = 20 , color = RED )
351+ ).arrange (RIGHT , buff = 0.2 )
352+
353+ legend_good = VGroup (
354+ Line (ORIGIN , RIGHT * 0.5 , color = GREEN , stroke_width = 3 ),
355+ Text ("Good Fit" , font_size = 20 , color = GREEN )
356+ ).arrange (RIGHT , buff = 0.2 )
357+
358+ legend_over = VGroup (
359+ Line (ORIGIN , RIGHT * 0.5 , color = ORANGE , stroke_width = 3 ),
360+ Text ("Overfit" , font_size = 20 , color = ORANGE )
361+ ).arrange (RIGHT , buff = 0.2 )
362+
363+ legend_real = VGroup (
364+ Line (ORIGIN , RIGHT * 0.5 , color = WHITE , stroke_width = 4 , stroke_opacity = 0.7 ),
365+ Text ("Real Distribution" , font_size = 20 , color = WHITE )
366+ ).arrange (RIGHT , buff = 0.2 )
367+
368+ legend = VGroup (legend_real , legend_under , legend_good , legend_over ).arrange (
369+ DOWN , aligned_edge = LEFT , buff = 0.2
370+ )
371+ legend .to_edge (LEFT ).shift (DOWN * 1.5 )
372+
373+ self .play (FadeIn (legend ))
374+ self .wait (3 )
375+
376+ # Error Analysis Section
377+ self .play (
378+ FadeOut (comparison_label ),
379+ FadeOut (legend )
380+ )
381+
382+ error_title = Text ("Error Analysis: Residuals vs Noise" , font_size = 32 , color = YELLOW )
383+ error_title .next_to (title , DOWN )
384+ self .play (Write (error_title ))
385+ self .wait (1 )
386+
387+ # Calculate errors for each model
388+ y_pred_linear = theta_linear [0 ] + theta_linear [1 ] * x_data
389+ y_pred_quad = theta_quad [0 ] + theta_quad [1 ] * x_data + theta_quad [2 ] * x_data ** 2
390+ y_pred_poly = np .array ([poly_func (x ) for x in x_data ])
391+
392+ # Calculate MSE for each model and noise
393+ mse_noise = np .mean ((y_data - y_true )** 2 )
394+ mse_underfit = np .mean ((y_data - y_pred_linear )** 2 )
395+ mse_goodfit = np .mean ((y_data - y_pred_quad )** 2 )
396+ mse_overfit = np .mean ((y_data - y_pred_poly )** 2 )
397+
398+ # Create error comparison table
399+ error_labels = VGroup (
400+ MathTex (r"\lVert\epsilon\rVert_2^2" , font_size = 28 , color = YELLOW ),
401+ MathTex (r"\lVert r_{\text{underfit}}\rVert_2^2" , font_size = 28 , color = RED ),
402+ MathTex (r"\lVert r_{\text{good fit}}\rVert_2^2" , font_size = 28 , color = GREEN ),
403+ MathTex (r"\lVert r_{\text{overfit}}\rVert_2^2" , font_size = 28 , color = ORANGE )
404+ ).arrange (DOWN , aligned_edge = LEFT , buff = 0.3 )
405+
406+ error_values = VGroup (
407+ MathTex (f" = { mse_noise :.3f} " , font_size = 28 , color = YELLOW ),
408+ MathTex (f" = { mse_underfit :.3f} " , font_size = 28 , color = RED ),
409+ MathTex (f" = { mse_goodfit :.3f} " , font_size = 28 , color = GREEN ),
410+ MathTex (f" = { mse_overfit :.3f} " , font_size = 28 , color = ORANGE )
411+ ).arrange (DOWN , aligned_edge = LEFT , buff = 0.3 )
412+
413+ # Align the values with their corresponding labels
414+ for i , (label , value ) in enumerate (zip (error_labels , error_values )):
415+ value .align_to (label , UP )
416+
417+ # Position labels and values
418+ error_labels .to_edge (LEFT ).shift (UP * 0.5 )
419+ error_values .next_to (error_labels , RIGHT , buff = 0.5 )
420+
421+ self .play (
422+ Write (error_labels ),
423+ Write (error_values )
424+ )
425+ self .wait (4 )
426+
427+ # Fade out everything
428+ self .play (* [FadeOut (mob ) for mob in self .mobjects ])
429+ self .wait (1 )
430+
431+
99432class BinaryClassificationSimple (Scene ):
100433 def construct (self ):
101434
@@ -230,7 +563,7 @@ def construct(self):
230563 )
231564
232565 best_params = MathTex (
233- f"\\ theta_0={ theta_0_best :.1f} , \\ theta_1={ theta_1_best :.1f} , \\ theta_2={ theta_2_best :.1f} " ,
566+ f"\\ theta_0^* ={ theta_0_best :.1f} , \\ theta_1^* ={ theta_1_best :.1f} , \\ theta_2^* ={ theta_2_best :.1f} " ,
234567 font_size = 32 ,
235568 color = GREEN
236569 )
0 commit comments