diff --git a/src/tinker/lib/public_interfaces/training_client.py b/src/tinker/lib/public_interfaces/training_client.py
index d99a90a..47aeffb 100644
--- a/src/tinker/lib/public_interfaces/training_client.py
+++ b/src/tinker/lib/public_interfaces/training_client.py
@@ -263,7 +263,24 @@ async def forward_backward_custom_async(
         import torch
 
         # First do a forward pass and get logprobs
-        forward_future = await self.forward_async(data, "cross_entropy")
+        _loss_fn_input_keys = data[0].loss_fn_inputs.keys()
+        if "advantages" in _loss_fn_input_keys:
+            # Check for RL (`importance_sampling`) loss inputs
+            assert "advantages" in _loss_fn_input_keys, "advantages must be in loss_fn_inputs"
+            assert "target_tokens" in _loss_fn_input_keys, "target_tokens must be in loss_fn_inputs"
+            assert "logprobs" in _loss_fn_input_keys, "logprobs must be in loss_fn_inputs"
+            _loss_fn = "importance_sampling"
+
+        elif "weights" in _loss_fn_input_keys:
+            # Check for supervised learning loss inputs
+            assert "weights" in _loss_fn_input_keys, "weights must be in loss_fn_inputs"
+            assert "target_tokens" in _loss_fn_input_keys, "target_tokens must be in loss_fn_inputs"
+            _loss_fn = "cross_entropy"
+        else:
+            assert False, "Invalid loss function inputs"
+
+        # Compute on-policy logprobs
+        forward_future = await self.forward_async(data, _loss_fn)
         forward_result = await forward_future.result_async()
         logprobs_list: List[torch.Tensor] = []
         for out in forward_result.loss_fn_outputs:
@@ -280,7 +297,7 @@ async def forward_backward_custom_async(
             grads.append(logprob.grad)
 
         linear_loss_data = []
-        for datum, grad in zip(data, grads):
+        for datum, grad in zip(data, grads, strict=True):
             loss_fn_inputs: Any = {
                 "target_tokens": datum.loss_fn_inputs["target_tokens"],
                 "weights": -grad,  # Pass PyTorch tensor directly (will be converted to TensorData)
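
For reference, below is a minimal standalone sketch of the key-based dispatch this patch introduces, with illustrative inputs for each path. The helper name infer_loss_fn and the tensor contents are assumptions made for this example only; they are not part of the tinker API, which in the patch performs this check inline on data[0].loss_fn_inputs.

import torch

def infer_loss_fn(loss_fn_inputs: dict) -> str:
    """Mirror of the key-based dispatch added in the patch (not the library API)."""
    keys = loss_fn_inputs.keys()
    if "advantages" in keys:
        # RL path: importance_sampling expects advantages, target_tokens, and logprobs
        assert "target_tokens" in keys and "logprobs" in keys
        return "importance_sampling"
    if "weights" in keys:
        # Supervised path: cross_entropy expects weights and target_tokens
        assert "target_tokens" in keys
        return "cross_entropy"
    raise ValueError("Invalid loss function inputs")

# Illustrative loss_fn_inputs for each path (values are placeholders)
rl_inputs = {
    "target_tokens": torch.tensor([1, 2, 3]),
    "advantages": torch.tensor([0.5, -0.1, 0.2]),
    "logprobs": torch.tensor([-1.2, -0.8, -2.0]),
}
sl_inputs = {
    "target_tokens": torch.tensor([1, 2, 3]),
    "weights": torch.tensor([1.0, 1.0, 0.0]),
}
assert infer_loss_fn(rl_inputs) == "importance_sampling"
assert infer_loss_fn(sl_inputs) == "cross_entropy"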