diff --git a/src/tinker/lib/public_interfaces/training_client.py b/src/tinker/lib/public_interfaces/training_client.py
index d99a90a..d0cb768 100644
--- a/src/tinker/lib/public_interfaces/training_client.py
+++ b/src/tinker/lib/public_interfaces/training_client.py
@@ -263,14 +263,36 @@ async def forward_backward_custom_async(
         import torch
 
         # First do a forward pass and get logprobs
-        forward_future = await self.forward_async(data, "cross_entropy")
+        # -> self.forward_async(data, loss_fn="cross_entropy") expects data to
+        # be Datum list with loss_fn_inputs containing "target_tokens" and "weights"
+        def convert_to_cross_entropy_datum(datum: types.Datum) -> types.Datum:
+            """
+            Build a cross_entropy-compatible copy of a tinker Datum.
+
+            Keeps only "target_tokens" and "weights" in loss_fn_inputs. The input
+            datum is not modified, so loss_fn still receives its original inputs.
+            """
+            _keys = ["target_tokens", "weights"]
+            loss_fn_inputs: Any = {k: v for k, v in datum.loss_fn_inputs.items() if k in _keys}
+            # Add weights to loss_fn_inputs (only when "advantages" is available)
+            # - Following https://github.com/thinking-machines-lab/tinker-cookbook/blob/5ae76f38111ec97e476b5f95b43903e675208b52/tinker_cookbook/rl/data_processing.py#L165
+            #   use advantage == 0 as a heuristic for tokens to ignore (weights = 0)
+            # - Cast the mask itself: torch.ones_like(mask) is all ones and drops the mask
+            # - Also pass PyTorch tensor directly (will be converted to TensorData)
+            if "advantages" in datum.loss_fn_inputs:
+                advantages = datum.loss_fn_inputs["advantages"].to_torch()  # type: ignore
+                loss_fn_inputs["weights"] = (advantages != 0).float()
+            return types.Datum(model_input=datum.model_input, loss_fn_inputs=loss_fn_inputs)
+
+        _data = list(map(convert_to_cross_entropy_datum, data))
+        forward_future = await self.forward_async(_data, "cross_entropy")
         forward_result = await forward_future.result_async()
         logprobs_list: List[torch.Tensor] = []
         for out in forward_result.loss_fn_outputs:
             logprob = torch.tensor(out["logprobs"].data).clone().detach().requires_grad_(True)
             logprobs_list.append(logprob)
 
-        # Now apply user-provided function
+        # Now apply user-provided function (on original data list)
         loss, metrics = loss_fn(data, logprobs_list)
         loss.backward()
         grads = []
@@ -280,7 +302,7 @@ async def forward_backward_custom_async(
             grads.append(logprob.grad)
 
         linear_loss_data = []
-        for datum, grad in zip(data, grads):
+        for datum, grad in zip(data, grads, strict=True):
             loss_fn_inputs: Any = {
                 "target_tokens": datum.loss_fn_inputs["target_tokens"],
                 "weights": -grad,  # Pass PyTorch tensor directly (will be converted to TensorData)