
Custom rl loss patch 2 (with key detection) #4

Open

mzio wants to merge 1 commit into thinking-machines-lab:main from mzio:custom_rl_loss_patch_2

Conversation

mzio commented Oct 10, 2025

See #2 (comment) and #3 (comment)

Main issue: at a high level, there is a conflict between how a user would specify an RL loss (and the Datum loss_fn_inputs it requires) and how that data gets processed in training_client, which expects supervised-learning loss_fn_inputs.

@capture_exceptions(fatal=True)
async def forward_backward_custom_async(
    self, data: List[types.Datum], loss_fn: CustomLossFnV1
) -> APIFuture[types.ForwardBackwardOutput]:
    import torch

    # First do a forward pass and get logprobs
    forward_future = await self.forward_async(data, "cross_entropy")
    forward_result = await forward_future.result_async()
    logprobs_list: List[torch.Tensor] = []
    for out in forward_result.loss_fn_outputs:
        logprob = torch.tensor(out["logprobs"].data).clone().detach().requires_grad_(True)
        logprobs_list.append(logprob)

    # Now apply user-provided function
    loss, metrics = loss_fn(data, logprobs_list)
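For concreteness, here is roughly what the two kinds of loss_fn_inputs look like. This is illustrative only: the values are placeholders rather than the SDK's actual tensor types, and the supervised "weights" key is an assumption about the cross_entropy path; the RL keys are the ones discussed in this PR.

    # Illustrative placeholders only -- not the SDK's exact tensor types.

    # Supervised-style loss_fn_inputs, which training_client currently expects
    # ("weights" here is an assumption about the cross_entropy loss):
    sl_loss_fn_inputs = {
        "target_tokens": [...],  # tokens to score
        "weights": [...],        # per-token loss weights
    }

    # RL-style loss_fn_inputs a user would pass for a custom RL loss:
    rl_loss_fn_inputs = {
        "target_tokens": [...],  # sampled tokens
        "logprobs": [...],       # sampler (behavior-policy) logprobs
        "advantages": [...],     # per-token or per-sequence advantages
    }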

The solution here is instead to detect whether the user is using an RL loss, based on the keys in the first Datum.loss_fn_inputs (a sketch of the detection follows the list below):

  • If advantages is present, it also asserts that the other expected keys (target_tokens, logprobs) are there
  • It then computes the on-policy logprobs with forward_future = await self.forward_async(data, "importance_sampling")
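A minimal sketch of the detection, written as a standalone helper for clarity. The helper name and its client argument are assumptions made for illustration; in the patch this logic would live inside forward_backward_custom_async, and the exact assertions and plumbing may differ.

    from typing import List

    async def _forward_with_key_detection(client, data: List["types.Datum"]):
        # Sketch only: choose the forward pass based on the keys present
        # in the first Datum's loss_fn_inputs.
        first_inputs = data[0].loss_fn_inputs
        if "advantages" in first_inputs:
            # RL case: if advantages are present, the other expected keys must be too.
            for key in ("target_tokens", "logprobs"):
                assert key in first_inputs, f"RL loss requires '{key}' in loss_fn_inputs"
            # On-policy logprobs come from the importance-sampling forward pass.
            forward_future = await client.forward_async(data, "importance_sampling")
        else:
            # Supervised case: keep the existing cross-entropy forward pass.
            forward_future = await client.forward_async(data, "cross_entropy")
        return await forward_future.result_async()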
