Allow multiple datasets in fit_to_data and add option to return opt_state #210
aseyboldt wants to merge 4 commits into danielward27:main
|
I can see this could be useful, e.g. it would give a simple way to support weighted samples (#211). I appreciate the effort to maintain backwards compatibility, but it does lead to a bit of an ugly interface. If rewriting from scratch, you could consider replacing [...]

To me there isn't really a compelling reason why condition should be passed as a separate argument once we have chosen to allow [...]

The ability to pass and return the optimizer state is a good idea, but if we include it in [...]

Although I feel a little unsure about it, I think I am happy to merge this with the addition of the following changes: [...]

I'm definitely open to feedback on this though, if you or anyone else has any thoughts: there is inevitably a trade-off between maintaining backwards compatibility and simplifying the code and improving the API. Users are probably better placed to discuss this than I am. In my applications I don't really mind quick-to-fix breaking changes (in both FlowJAX and other dependencies), but I understand that may be different for others.
|
An issue I have just thought about: it may be required to support [...]
954fdd7 to 4acaa15
|
Yeah, I also wasn't too happy about introducing the [...]
4acaa15 to 06bf498
06bf498 to ff3d30e
|
Hmm, I'll have to have a think. I wouldn't merge this as is because of the breaking change. Maybe it's possible to add x as a keyword-only argument and handle it appropriately to give a deprecation warning without any breaking changes. Also I think the aforementioned issue, of a loss requiring the following data [...]
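One possible reading of the keyword-only deprecation idea, as a minimal sketch. The function name `fit_compat_sketch` and its signature are hypothetical and are not FlowJAX's API; the sketch only shows how legacy keywords could be folded into a tuple of positional arrays while emitting a `DeprecationWarning`.

```python
import warnings


def fit_compat_sketch(*arrays, x=None, condition=None):
    """Hypothetical signature: arrays are positional, x/condition are legacy keywords."""
    if x is not None or condition is not None:
        warnings.warn(
            "The x and condition keywords are deprecated; pass arrays positionally.",
            DeprecationWarning,
            stacklevel=2,
        )
    if x is not None:
        if arrays:
            raise TypeError("x was given both positionally and by keyword.")
        arrays = (x,)
    if condition is not None:
        if not arrays:
            raise TypeError("condition was given without any data array.")
        # Old-style calls such as fit(x, condition=c) keep working.
        arrays = (*arrays, condition)
    if not arrays:
        raise TypeError("At least one array is required.")
    return arrays
```

Under these assumptions, old-style calls keep their meaning and only trigger a warning, while new-style calls pass any number of arrays positionally.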
|
How about this version? This avoids the breaking change by having the old [...]
|
Didn't mean to close...
|
It is a bit fiddly. I think not deprecating condition would make for too confusing and messy an interface if we were to go down this route. I gave it a bit of a go here: https://github.com/danielward27/flowjax/tree/multiple_arrays_fit (just focusing on multiple arrays, not adding [...]). It seems to work without breaking changes. Maybe it's important to support [...]

Alternatively, you could argue users should be forced to e.g. wrap the function to remove any non-array arguments (in which case some of my changes are unnecessary). However, this is arguably a little confusing, especially if we have a loss needing arrays [...]

I am still a bit uncertain about whether to support this. The [...]

By the way, you have a [...]
|
I like your version :-)
Some loss functions can require several arrays instead of only one. This extends `fit_to_data` so that it passes batches of those arrays to the loss function.

It can also be quite useful to reuse the optimizer state across different runs, so I added `opt_state` and `return_opt_state` arguments to `fit_to_data` to facilitate that.
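To illustrate the two ideas in this description, here is a minimal NumPy sketch. It is not FlowJAX's `fit_to_data`: the function `fit_sketch`, the loss `weighted_mse`, and the hand-rolled momentum optimizer are all hypothetical stand-ins. Only the argument names `opt_state` and `return_opt_state` come from the PR. The sketch batches several arrays with one shared permutation and lets the caller pass in and get back the optimizer state.

```python
import numpy as np


def fit_sketch(loss_and_grad, params, arrays, *, opt_state=None,
               return_opt_state=False, learning_rate=0.05, momentum=0.5,
               batch_size=32, max_epochs=20, seed=0):
    """Hypothetical minibatch loop over several arrays (not FlowJAX's API)."""
    arrays = tuple(np.asarray(a) for a in arrays)
    n = arrays[0].shape[0]
    if any(a.shape[0] != n for a in arrays):
        raise ValueError("All arrays must share the same leading dimension.")
    # The "optimizer state" in this sketch is just the momentum velocity.
    velocity = np.zeros_like(params) if opt_state is None else opt_state
    rng = np.random.default_rng(seed)
    losses = []
    for _ in range(max_epochs):
        perm = rng.permutation(n)  # one permutation shared by every array
        epoch_losses = []
        for start in range(0, n, batch_size):
            idx = perm[start:start + batch_size]
            batch = tuple(a[idx] for a in arrays)
            loss, grad = loss_and_grad(params, *batch)
            velocity = momentum * velocity - learning_rate * grad
            params = params + velocity
            epoch_losses.append(loss)
        losses.append(float(np.mean(epoch_losses)))
    if return_opt_state:
        return params, losses, velocity
    return params, losses


def weighted_mse(params, x, w):
    """Weighted squared error for a location estimate; needs two arrays."""
    resid = params - x
    loss = np.sum(w * resid**2) / np.sum(w)
    grad = 2 * np.sum(w * resid) / np.sum(w)
    return loss, grad


data_rng = np.random.default_rng(1)
x = data_rng.normal(3.0, 1.0, size=256)
w = data_rng.uniform(0.5, 1.5, size=256)

# First run returns the optimizer state; the second run resumes from it.
params, losses, state = fit_sketch(
    weighted_mse, np.array(0.0), (x, w), return_opt_state=True
)
params, more_losses, state = fit_sketch(
    weighted_mse, params, (x, w), opt_state=state, return_opt_state=True
)
```

Sharing a single permutation across the arrays keeps each sample aligned with its weight, which is the property a weighted-sample loss relies on.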