
Custom rl loss patch 1 #3

Open

mzio wants to merge 2 commits into thinking-machines-lab:main from mzio:custom_rl_loss_patch_1

Conversation


@mzio mzio commented Oct 10, 2025

See #2 (comment)

Main issue: At a high level, there seems to be a conflict between how a user would specify an RL loss (and the Datum loss_fn_inputs it requires) and how this gets processed in training_client, where supervised-learning loss_fn_inputs are expected.
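
For concreteness, a rough illustration of the mismatch (illustrative dicts only; advantages and weights are the keys named in this thread, the values are made up):

# What an RL user naturally provides per Datum:
rl_loss_fn_inputs = {"advantages": [0.0, 0.0, 1.7, 1.7]}
# What the built-in "cross_entropy" path expects per Datum:
sl_loss_fn_inputs = {"weights": [0.0, 0.0, 1.0, 1.0]}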

@capture_exceptions(fatal=True)
async def forward_backward_custom_async(
    self, data: List[types.Datum], loss_fn: CustomLossFnV1
) -> APIFuture[types.ForwardBackwardOutput]:
    import torch
    # First do a forward pass and get logprobs
    forward_future = await self.forward_async(data, "cross_entropy")
    forward_result = await forward_future.result_async()
    logprobs_list: List[torch.Tensor] = []
    for out in forward_result.loss_fn_outputs:
        logprob = torch.tensor(out["logprobs"].data).clone().detach().requires_grad_(True)
        logprobs_list.append(logprob)
    # Now apply user-provided function
    loss, metrics = loss_fn(data, logprobs_list)

The solution here is to instead create a separate copy of the Datum list, where each element now has a weights key in loss_fn_inputs. We infer 0 vs. 1 based on whether the advantages in the original datum list are zero or not (maybe a bad heuristic). A sketch of this conversion helper is below.
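
A minimal sketch of what that conversion helper could look like. The deepcopy and the in-place loss_fn_inputs update are assumptions about how Datum can be copied and mutated; only the weights/advantages keys and the zero-vs-nonzero heuristic come from the description above.

import copy
import torch

def convert_to_cross_entropy_datum(datum):
    # Work on a copy so the original RL datum (with its advantages) stays untouched.
    new_datum = copy.deepcopy(datum)
    adv = new_datum.loss_fn_inputs["advantages"]
    # Unwrap a tensor-data wrapper if present (the forward output above exposes `.data`).
    adv = torch.as_tensor(getattr(adv, "data", adv), dtype=torch.float32)
    # Heuristic from above: weight = 1 wherever the advantage is nonzero, 0 elsewhere.
    new_datum.loss_fn_inputs["weights"] = (adv != 0).to(adv.dtype)
    return new_datum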

We can then use the copy to compute logprobs as we do currently, while applying the user's custom loss_fn to the original data, e.g.:

# convert
_data_for_xent = list(map(convert_to_cross_entropy_datum, data))

# get on-policy logprobs
forward_future = await self.forward_async(_data_for_xent, "cross_entropy")
forward_result = await forward_future.result_async()
logprobs_list: List[torch.Tensor] = []
for out in forward_result.loss_fn_outputs:
    logprob = torch.tensor(out["logprobs"].data).clone().detach().requires_grad_(True)
    logprobs_list.append(logprob)

# apply user-provided function (on original data list)
loss, metrics = loss_fn(data, logprobs_list)
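
For completeness, a hypothetical user-provided loss_fn matching the (data, logprobs_list) -> (loss, metrics) signature used above: a simple advantage-weighted (REINFORCE-style) objective. The advantages lookup mirrors the heuristic discussed earlier and is an assumption about where they live in loss_fn_inputs.

import torch

def reinforce_loss(data, logprobs_list):
    per_datum_losses = []
    for datum, logprobs in zip(data, logprobs_list):
        adv = datum.loss_fn_inputs["advantages"]
        # Unwrap a tensor-data wrapper if present, as with logprobs above.
        adv = torch.as_tensor(getattr(adv, "data", adv), dtype=logprobs.dtype)
        # Negative advantage-weighted log-likelihood for this sequence.
        per_datum_losses.append(-(adv * logprobs).sum())
    loss = torch.stack(per_datum_losses).mean()
    metrics = {"loss": loss.item()}
    return loss, metrics

# Passed in as the custom loss:
# loss, metrics = reinforce_loss(data, logprobs_list)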
