@@ -59,7 +59,7 @@ def check_points(dataset, split_ratios, split_names, i: int):
5959 LOGGER .info ("" )
6060 overflows = [(pn , s ) if ps > s else (None , None ) for (pn , ps ), s in zip (sorted_points , sorted (split_ratios , reverse = True ))]
6161 overflow_point = next ((i , pn , s ) for i , (pn , s ) in enumerate (overflows ) if pn is not None )
62- dataset , name_split_map , cluster_split_map , split_ratios , split_names = assign_cluster (dataset , overflow_point [1 ], split_ratios , split_names , overflow_point [0 ])
62+ dataset , name_split_map , cluster_split_map , split_ratios , split_names = assign_cluster (dataset , overflow_point [1 ], split_ratios , split_names , overflow_point [0 ], clusters = False )
6363 return dataset , name_split_map , cluster_split_map , split_ratios , split_names
6464
6565
@@ -81,19 +81,22 @@ def check_clusters(dataset, split_ratios, split_names, strategy: Literal["break"
8181 return dataset , name_split_map , cluster_split_map , split_ratios , split_names
8282
8383
84- def assign_cluster (dataset : DataSet , cluster_name : Any , split_ratios , split_names , split_index ) -> DataSet :
84+ def assign_cluster (dataset : DataSet , cluster_name : Any , split_ratios , split_names , split_index , clusters : bool = True ) -> DataSet :
8585 split_name = split_names [split_index ]
8686 split_ratios = split_ratios [:split_index ] + split_ratios [split_index + 1 :]
8787 split_names = split_names [:split_index ] + split_names [split_index + 1 :]
8888
89- if dataset . cluster_map is not None :
89+ if clusters :
9090 cluster_index = dataset .cluster_names .index (cluster_name )
9191 name_split_map = {}
9292 cluster_split_map = {cluster_name : split_name }
9393 for n in dataset .names :
9494 if dataset .cluster_map [n ] == cluster_name :
9595 name_split_map [n ] = split_name
9696 dataset .cluster_names = dataset .cluster_names [:cluster_index ] + dataset .cluster_names [cluster_index + 1 :]
97+ del dataset .cluster_weights [cluster_name ]
98+ if dataset .cluster_stratification is not None :
99+ del dataset .cluster_stratification [cluster_name ]
97100 if dataset .cluster_similarity is not None :
98101 dataset .cluster_similarity = np .delete (dataset .cluster_similarity , cluster_index , axis = 0 )
99102 dataset .cluster_similarity = np .delete (dataset .cluster_similarity , cluster_index , axis = 1 )
@@ -105,6 +108,9 @@ def assign_cluster(dataset: DataSet, cluster_name: Any, split_ratios, split_name
105108 cluster_split_map = {}
106109 name_index = dataset .names .index (cluster_name )
107110 dataset .names = dataset .names [:name_index ] + dataset .names [name_index + 1 :]
111+ del dataset .weights [cluster_name ]
112+ if dataset .stratification is not None :
113+ del dataset .stratification [cluster_name ]
108114 if dataset .similarity is not None :
109115 dataset .similarity = np .delete (dataset .similarity , name_index , axis = 0 )
110116 dataset .similarity = np .delete (dataset .similarity , name_index , axis = 1 )
0 commit comments