21 changes: 19 additions & 2 deletions src/tinker/lib/public_interfaces/training_client.py
@@ -263,7 +263,24 @@ async def forward_backward_custom_async(
             import torch
 
             # First do a forward pass and get logprobs
-            forward_future = await self.forward_async(data, "cross_entropy")
+            _loss_fn_input_keys = data[0].loss_fn_inputs.keys()
+            if "advantages" in _loss_fn_input_keys:
+                # Check for RL (`importance_sampling`) loss inputs
+                assert "advantages" in _loss_fn_input_keys, "advantages must be in loss_fn_inputs"
+                assert "target_tokens" in _loss_fn_input_keys, "target_tokens must be in loss_fn_inputs"
+                assert "logprobs" in _loss_fn_input_keys, "logprobs must be in loss_fn_inputs"
+                _loss_fn = "importance_sampling"
+
+            elif "weights" in _loss_fn_input_keys:
+                # Check for supervised learning loss inputs
+                assert "weights" in _loss_fn_input_keys, "weights must be in loss_fn_inputs"
+                assert "target_tokens" in _loss_fn_input_keys, "target_tokens must be in loss_fn_inputs"
+                _loss_fn = "cross_entropy"
+            else:
+                assert False, "Invalid loss function inputs"
+
+            # Compute on-policy logprobs
+            forward_future = await self.forward_async(data, _loss_fn)
             forward_result = await forward_future.result_async()
             logprobs_list: List[torch.Tensor] = []
             for out in forward_result.loss_fn_outputs:
@@ -280,7 +297,7 @@ async def forward_backward_custom_async(
                 grads.append(logprob.grad)
 
             linear_loss_data = []
-            for datum, grad in zip(data, grads):
+            for datum, grad in zip(data, grads, strict=True):
                 loss_fn_inputs: Any = {
                     "target_tokens": datum.loss_fn_inputs["target_tokens"],
                     "weights": -grad,  # Pass PyTorch tensor directly (will be converted to TensorData)
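For reviewers, a standalone sketch of the dispatch rule this change introduces: the loss function is inferred from which keys appear in the first datum's `loss_fn_inputs` (RL-style inputs select `importance_sampling`, supervised inputs select `cross_entropy`). The helper name `infer_loss_fn` and the plain-dict inputs below are illustrative assumptions, not part of the tinker API.

# Illustrative sketch only (assumed helper, not tinker API): mirrors the
# key-based loss selection added in this diff.
from typing import Any, Mapping


def infer_loss_fn(loss_fn_inputs: Mapping[str, Any]) -> str:
    keys = set(loss_fn_inputs.keys())
    if "advantages" in keys:
        # RL-style inputs: advantages, target_tokens, and sampling-time logprobs
        assert {"target_tokens", "logprobs"} <= keys, "RL inputs need target_tokens and logprobs"
        return "importance_sampling"
    if "weights" in keys:
        # Supervised inputs: per-token weights and target_tokens
        assert "target_tokens" in keys, "SL inputs need target_tokens"
        return "cross_entropy"
    raise ValueError("Invalid loss function inputs")


# Plain dicts stand in for Datum.loss_fn_inputs:
assert infer_loss_fn({"advantages": [], "target_tokens": [], "logprobs": []}) == "importance_sampling"
assert infer_loss_fn({"weights": [], "target_tokens": []}) == "cross_entropy"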