Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 81 additions & 44 deletions rdagent/scenarios/data_science/proposal/exp_gen/trace_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,16 +328,18 @@ class MCTSScheduler(ProbabilisticScheduler):
- Keep NEW_ROOT policy and uncommitted status handling identical to base classes.
"""

ROOT_ID = -1

def __init__(self, max_trace_num: int, temperature: float = 1.0, *args, **kwargs):
super().__init__(max_trace_num, temperature)
# Read c_puct from settings if available, otherwise fall back to default 1.0
self.c_puct = getattr(DS_RD_SETTING, "scheduler_c_puct", 1.0) or 1.0
self.c_uct = getattr(DS_RD_SETTING, "scheduler_c_uct", 1.0) or 1.0
# Statistics keyed by leaf node index
self.node_visit_count: dict[int, int] = {}
self.node_value_sum: dict[int, float] = {}
self.node_prior: dict[int, float] = {}
# Global counter to stabilize U term
self.global_visit_count: int = 0

self.node_visit_count[self.ROOT_ID] = 1
self.node_value_sum[self.ROOT_ID] = 0.0

# Last observed commit index for batch feedback observation
self.last_observed_commit_idx: int = 0

Expand All @@ -349,49 +351,81 @@ def _get_q(self, node_id: int) -> float:
return 0.0
return value_sum / visits

def _get_u(self, node_id: int) -> float:
prior = self.node_prior.get(node_id, 0.0)
def _get_parents(self, node_id: int, trace: DSTrace) -> list[int]:
"""
Due to the MCTS algorithm will have a virtual root node, which does not exist in the trace data structure.
"""
if node_id == self.ROOT_ID:
parents = []
else:
parents = trace.get_parents(node_id)
parents_with_root = [self.ROOT_ID] + parents
return parents_with_root

def _get_all_nodes(self, trace: DSTrace) -> list[int]:
"""
Due to the MCTS algorithm will have a virtual root node, which does not exist in the trace data structure.
"""
return [self.ROOT_ID] + list(range(len(trace.hist)))

def _get_u_uct(self, node_id: int, trace: DSTrace) -> float:
parents = self._get_parents(node_id, trace)

if node_id == self.ROOT_ID:
last_parent_id = self.ROOT_ID
else:
last_parent_id = parents[-2]

parent_visits = self.node_visit_count.get(last_parent_id, 0)
visits = self.node_visit_count.get(node_id, 0)
# Avoid div-by-zero; encourage exploration when visits are small
return self.c_puct * prior * math.sqrt(max(1, self.global_visit_count)) / (1 + visits)
N = max(1, parent_visits)
n = max(1, visits)
return self.c_uct * math.sqrt(math.log(N) / n)

def select(self, trace: DSTrace) -> tuple[int, ...] | None:
def _select(self, trace: DSTrace) -> tuple[int, ...] | None:
# Step 1: keep same policy to reach target number of parallel traces
# TODO: expanding from the virtual root node is implemented in a rule-based way.
if trace.sub_trace_count + self.uncommited_rec_status[trace.NEW_ROOT] < self.max_trace_num:
return trace.NEW_ROOT

# Step 2: consider only available leaves (not being expanded)
available_leaves = list(set(range(len(trace.hist))))
if not available_leaves:
return None

# Step 3: compute priors (P) from potentials via softmax
potentials = [self.calculate_potential(trace, leaf) for leaf in available_leaves]
if any(p < 0 for p in potentials):
raise ValueError("Potential function returned a negative value.")
priors = self._softmax_probabilities(potentials)
for leaf, p in zip(available_leaves, priors):
self.node_prior[leaf] = p
candidates = list(available_leaves) # copy
candidates_with_root = candidates + [self.ROOT_ID]

candidates_with_root = self._get_all_nodes(trace)

# Step 4: score each leaf using PUCT-like rule: Q + U
best_leaf = None
best_score = -float("inf")
for leaf in available_leaves:
q = self._get_q(leaf)
u = self._get_u(leaf)
score = q + u
if score > best_score:
best_score = score
best_leaf = leaf

if best_leaf is None:
score_id_pairs = [(self._get_q(nid) + self._get_u_uct(nid, trace), nid) for nid in candidates_with_root]
score_id_pairs.sort(reverse=True)

if len(score_id_pairs) == 0:
return None

# # Step 5: optimistic visit update on selection; value update deferred to observe_feedback
self.global_visit_count += 1
best_node, _ = score_id_pairs[0]

if best_node == self.ROOT_ID and len(score_id_pairs) > 1:
# Motivation: we sometimes want to limit the expansion of the root node.
# capacity full: pick next best real leaf
capacity = trace.sub_trace_count + self.uncommited_rec_status.get(trace.NEW_ROOT, 0)
if capacity >= self.max_trace_num:
second_best, _ = score_id_pairs[1]
return (second_best,)

return (best_leaf,)
return (best_node,)

def select(self, trace: DSTrace) -> tuple[int, ...] | None:
"""
In MCTS, we have a virtual root node, expanding from the virutal root node will return (-1,).
But in the trace DAG, expanding a new node from root node should return (trace.NEW_ROOT,).
"""
base_nodes = self._select(trace)
if base_nodes == (self.ROOT_ID,):
return trace.NEW_ROOT
return base_nodes

def sigmoid(self, x):
return 1 / (1 + math.exp(-x))

def scaled_tanh(self, x):
# tanh -> (-1,1), then scale to (0,1)
return (math.tanh(x) + 1.0) / 2.0

def observe_feedback(self, trace: DSTrace, new_idx: int) -> None:
"""
Expand All @@ -406,13 +440,18 @@ def observe_feedback(self, trace: DSTrace, new_idx: int) -> None:
re, fb = trace.hist[new_idx]
if DS_RD_SETTING.enable_score_reward:
bigger_is_better = get_metric_direction(trace.scen.competition)
if getattr(fb, "decision", False):
reward = math.tanh(re.result.loc["ensemble"].iloc[0].round(3)) * (1 if bigger_is_better else -1)
if re.result is not None:
if bigger_is_better:
reward = self.scaled_tanh(re.result.loc["ensemble"].iloc[0])
else:
reward = 1 - self.scaled_tanh(re.result.loc["ensemble"].iloc[0])
else:
reward = -1 if bigger_is_better else 1
reward = 0 if bigger_is_better else 1
else:
reward = 1.0 if getattr(fb, "decision", False) else 0.0
id_list = trace.get_parents(new_idx)

id_list = self._get_parents(new_idx, trace)

for id in id_list:
self.node_value_sum[id] = self.node_value_sum.get(id, 0.0) + float(reward)
self.node_visit_count[id] = self.node_visit_count.get(id, 0) + 1
Expand All @@ -424,8 +463,6 @@ def reset(self) -> None:
super().reset()
self.node_visit_count.clear()
self.node_value_sum.clear()
self.node_prior.clear()
self.global_visit_count = 0
self.last_observed_commit_idx = 0

def process_uncommitted_nodes(self, trace: DSTrace) -> None:
Expand Down
Loading