2323from bigframes .ml import base , core , globals , utils
2424import bigframes .pandas as bpd
2525import third_party .bigframes_vendored .sklearn .preprocessing ._data
26+ import third_party .bigframes_vendored .sklearn .preprocessing ._discretization
2627import third_party .bigframes_vendored .sklearn .preprocessing ._encoder
2728import third_party .bigframes_vendored .sklearn .preprocessing ._label
2829
@@ -44,12 +45,15 @@ def __init__(self):
4445 def __eq__ (self , other : Any ) -> bool :
4546 return type (other ) is StandardScaler and self ._bqml_model == other ._bqml_model
4647
47- def _compile_to_sql (self , columns : List [str ]) -> List [Tuple [str , str ]]:
48+ def _compile_to_sql (self , columns : List [str ], X = None ) -> List [Tuple [str , str ]]:
4849 """Compile this transformer to a list of SQL expressions that can be included in
4950 a BQML TRANSFORM clause
5051
5152 Args:
52- columns: a list of column names to transform
53+ columns:
54+ a list of column names to transform.
55+ X (default None):
56+ Ignored.
5357
5458 Returns: a list of tuples of (sql_expression, output_name)"""
5559 return [
@@ -124,12 +128,15 @@ def __init__(self):
124128 def __eq__ (self , other : Any ) -> bool :
125129 return type (other ) is MaxAbsScaler and self ._bqml_model == other ._bqml_model
126130
127- def _compile_to_sql (self , columns : List [str ]) -> List [Tuple [str , str ]]:
131+ def _compile_to_sql (self , columns : List [str ], X = None ) -> List [Tuple [str , str ]]:
128132 """Compile this transformer to a list of SQL expressions that can be included in
129133 a BQML TRANSFORM clause
130134
131135 Args:
132- columns: a list of column names to transform
136+ columns:
137+ a list of column names to transform.
138+ X (default None):
139+ Ignored.
133140
134141 Returns: a list of tuples of (sql_expression, output_name)"""
135142 return [
@@ -204,12 +211,15 @@ def __init__(self):
204211 def __eq__ (self , other : Any ) -> bool :
205212 return type (other ) is MinMaxScaler and self ._bqml_model == other ._bqml_model
206213
207- def _compile_to_sql (self , columns : List [str ]) -> List [Tuple [str , str ]]:
214+ def _compile_to_sql (self , columns : List [str ], X = None ) -> List [Tuple [str , str ]]:
208215 """Compile this transformer to a list of SQL expressions that can be included in
209216 a BQML TRANSFORM clause
210217
211218 Args:
212- columns: a list of column names to transform
219+ columns:
220+ a list of column names to transform.
221+ X (default None):
222+ Ignored.
213223
214224 Returns: a list of tuples of (sql_expression, output_name)"""
215225 return [
@@ -267,6 +277,124 @@ def transform(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
267277 )
268278
269279
280+ class KBinsDiscretizer (
281+ base .Transformer ,
282+ third_party .bigframes_vendored .sklearn .preprocessing ._discretization .KBinsDiscretizer ,
283+ ):
284+ __doc__ = (
285+ third_party .bigframes_vendored .sklearn .preprocessing ._discretization .KBinsDiscretizer .__doc__
286+ )
287+
288+ def __init__ (
289+ self ,
290+ n_bins : int = 5 ,
291+ strategy : Literal ["uniform" , "quantile" ] = "quantile" ,
292+ ):
293+ if strategy != "uniform" :
294+ raise NotImplementedError (
295+ f"Only strategy = 'uniform' is supported now, input is { strategy } ."
296+ )
297+ if n_bins < 2 :
298+ raise ValueError (
299+ f"n_bins has to be larger than or equal to 2, input is { n_bins } ."
300+ )
301+ self .n_bins = n_bins
302+ self .strategy = strategy
303+ self ._bqml_model : Optional [core .BqmlModel ] = None
304+ self ._bqml_model_factory = globals .bqml_model_factory ()
305+ self ._base_sql_generator = globals .base_sql_generator ()
306+
307+ # TODO(garrettwu): implement __hash__
308+ def __eq__ (self , other : Any ) -> bool :
309+ return (
310+ type (other ) is KBinsDiscretizer
311+ and self .n_bins == other .n_bins
312+ and self ._bqml_model == other ._bqml_model
313+ )
314+
315+ def _compile_to_sql (
316+ self ,
317+ columns : List [str ],
318+ X : bpd .DataFrame ,
319+ ) -> List [Tuple [str , str ]]:
320+ """Compile this transformer to a list of SQL expressions that can be included in
321+ a BQML TRANSFORM clause
322+
323+ Args:
324+ columns:
325+ a list of column names to transform
326+ X:
327+ The Dataframe with training data.
328+
329+ Returns: a list of tuples of (sql_expression, output_name)"""
330+ array_split_points = {}
331+ if self .strategy == "uniform" :
332+ for column in columns :
333+ min_value = X [column ].min ()
334+ max_value = X [column ].max ()
335+ bin_size = (max_value - min_value ) / self .n_bins
336+ array_split_points [column ] = [
337+ min_value + i * bin_size for i in range (self .n_bins - 1 )
338+ ]
339+
340+ return [
341+ (
342+ self ._base_sql_generator .ml_bucketize (
343+ column , array_split_points [column ], f"kbinsdiscretizer_{ column } "
344+ ),
345+ f"kbinsdiscretizer_{ column } " ,
346+ )
347+ for column in columns
348+ ]
349+
350+ @classmethod
351+ def _parse_from_sql (cls , sql : str ) -> tuple [KBinsDiscretizer , str ]:
352+ """Parse SQL to tuple(KBinsDiscretizer, column_label).
353+
354+ Args:
355+ sql: SQL string of format "ML.BUCKETIZE({col_label}, array_split_points, FALSE) OVER()"
356+
357+ Returns:
358+ tuple(KBinsDiscretizer, column_label)"""
359+ s = sql [sql .find ("(" ) + 1 : sql .find (")" )]
360+ array_split_points = s [s .find ("[" ) + 1 : s .find ("]" )]
361+ col_label = s [: s .find ("," )]
362+ n_bins = array_split_points .count ("," ) + 2
363+ return cls (n_bins , "uniform" ), col_label
364+
365+ def fit (
366+ self ,
367+ X : Union [bpd .DataFrame , bpd .Series ],
368+ y = None , # ignored
369+ ) -> KBinsDiscretizer :
370+ (X ,) = utils .convert_to_dataframe (X )
371+
372+ compiled_transforms = self ._compile_to_sql (X .columns .tolist (), X )
373+ transform_sqls = [transform_sql for transform_sql , _ in compiled_transforms ]
374+
375+ self ._bqml_model = self ._bqml_model_factory .create_model (
376+ X ,
377+ options = {"model_type" : "transform_only" },
378+ transforms = transform_sqls ,
379+ )
380+
381+ # The schema of TRANSFORM output is not available in the model API, so save it during fitting
382+ self ._output_names = [name for _ , name in compiled_transforms ]
383+ return self
384+
385+ def transform (self , X : Union [bpd .DataFrame , bpd .Series ]) -> bpd .DataFrame :
386+ if not self ._bqml_model :
387+ raise RuntimeError ("Must be fitted before transform" )
388+
389+ (X ,) = utils .convert_to_dataframe (X )
390+
391+ df = self ._bqml_model .transform (X )
392+ return typing .cast (
393+ bpd .DataFrame ,
394+ df [self ._output_names ],
395+ )
396+
397+
270398class OneHotEncoder (
271399 base .Transformer ,
272400 third_party .bigframes_vendored .sklearn .preprocessing ._encoder .OneHotEncoder ,
@@ -308,13 +436,15 @@ def __eq__(self, other: Any) -> bool:
308436 and self .max_categories == other .max_categories
309437 )
310438
311- def _compile_to_sql (self , columns : List [str ]) -> List [Tuple [str , str ]]:
439+ def _compile_to_sql (self , columns : List [str ], X = None ) -> List [Tuple [str , str ]]:
312440 """Compile this transformer to a list of SQL expressions that can be included in
313441 a BQML TRANSFORM clause
314442
315443 Args:
316444 columns:
317- a list of column names to transform
445+ a list of column names to transform.
446+ X (default None):
447+ Ignored.
318448
319449 Returns: a list of tuples of (sql_expression, output_name)"""
320450
@@ -432,13 +562,15 @@ def __eq__(self, other: Any) -> bool:
432562 and self .max_categories == other .max_categories
433563 )
434564
435- def _compile_to_sql (self , columns : List [str ]) -> List [Tuple [str , str ]]:
565+ def _compile_to_sql (self , columns : List [str ], X = None ) -> List [Tuple [str , str ]]:
436566 """Compile this transformer to a list of SQL expressions that can be included in
437567 a BQML TRANSFORM clause
438568
439569 Args:
440570 columns:
441- a list of column names to transform
571+ a list of column names to transform.
572+ X (default None):
573+ Ignored.
442574
443575 Returns: a list of tuples of (sql_expression, output_name)"""
444576
0 commit comments