Skip to content

Commit 583a641

Browse files
Another attempt at fixing that bloody error
1 parent 212db8e commit 583a641

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

datasail/solver/overflow.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def check_points(dataset, split_ratios, split_names, i: int):
5959
LOGGER.info("")
6060
overflows = [(pn, s) if ps > s else (None, None) for (pn, ps), s in zip(sorted_points, sorted(split_ratios, reverse=True))]
6161
overflow_point = next((i, pn, s) for i, (pn, s) in enumerate(overflows) if pn is not None)
62-
dataset, name_split_map, cluster_split_map, split_ratios, split_names = assign_cluster(dataset, overflow_point[1], split_ratios, split_names, overflow_point[0])
62+
dataset, name_split_map, cluster_split_map, split_ratios, split_names = assign_cluster(dataset, overflow_point[1], split_ratios, split_names, overflow_point[0], clusters=False)
6363
return dataset, name_split_map, cluster_split_map, split_ratios, split_names
6464

6565

@@ -81,19 +81,22 @@ def check_clusters(dataset, split_ratios, split_names, strategy: Literal["break"
8181
return dataset, name_split_map, cluster_split_map, split_ratios, split_names
8282

8383

84-
def assign_cluster(dataset: DataSet, cluster_name: Any, split_ratios, split_names, split_index) -> DataSet:
84+
def assign_cluster(dataset: DataSet, cluster_name: Any, split_ratios, split_names, split_index, clusters: bool = True) -> DataSet:
8585
split_name = split_names[split_index]
8686
split_ratios = split_ratios[:split_index] + split_ratios[split_index + 1:]
8787
split_names = split_names[:split_index] + split_names[split_index + 1:]
8888

89-
if dataset.cluster_map is not None:
89+
if clusters:
9090
cluster_index = dataset.cluster_names.index(cluster_name)
9191
name_split_map = {}
9292
cluster_split_map = {cluster_name: split_name}
9393
for n in dataset.names:
9494
if dataset.cluster_map[n] == cluster_name:
9595
name_split_map[n] = split_name
9696
dataset.cluster_names = dataset.cluster_names[:cluster_index] + dataset.cluster_names[cluster_index + 1:]
97+
del dataset.cluster_weights[cluster_name]
98+
if dataset.cluster_stratification is not None:
99+
del dataset.cluster_stratification[cluster_name]
97100
if dataset.cluster_similarity is not None:
98101
dataset.cluster_similarity = np.delete(dataset.cluster_similarity, cluster_index, axis=0)
99102
dataset.cluster_similarity = np.delete(dataset.cluster_similarity, cluster_index, axis=1)
@@ -105,6 +108,9 @@ def assign_cluster(dataset: DataSet, cluster_name: Any, split_ratios, split_name
105108
cluster_split_map = {}
106109
name_index = dataset.names.index(cluster_name)
107110
dataset.names = dataset.names[:name_index] + dataset.names[name_index + 1:]
111+
del dataset.weights[cluster_name]
112+
if dataset.stratification is not None:
113+
del dataset.stratification[cluster_name]
108114
if dataset.similarity is not None:
109115
dataset.similarity = np.delete(dataset.similarity, name_index, axis=0)
110116
dataset.similarity = np.delete(dataset.similarity, name_index, axis=1)

0 commit comments

Comments
 (0)