Conversation
Jammy2211
left a comment
There was a problem hiding this comment.
Don't merge yet -- I'll build a test case this week and see whether any issues come up.
Couple of things to think about, albeit best not to worry about them until we've done some testing:
- Does this all play nicely with our JAX / Pytrees wrapping?
- How does this play with search chaining API?
- Is there a clean interface with the default-prior yaml configs?
|
Need to check JAX and search chaining. We mostly do config on the type -> prior relationship so it was hard to see how to carry this over |
|
I would make it so that the config default gives the same prior to everything on the numpy array. We can worry about JAX and search chaining later on then! |
So at the moment you define one prior and that gets copied to every slot. Would we want a config to determine the default for all Arrays globally? |
|
Should now work with prior passing and JAX |
| ---------- | ||
| shape : (int, int) | ||
| The shape of the array. | ||
| prior : Prior |
There was a problem hiding this comment.
can you describe what happens if the parameter prior is not provided?
| """ | ||
| super().__init__() | ||
| self.shape = shape | ||
| self.indices = list(np.ndindex(*shape)) |
There was a problem hiding this comment.
you are converting self.indices to a list but you are usint the typing Tuple[int, ...] everywhere.
Adds the Array class which functions as a PriorModel that creates numpy arrays of floats.
An array is defined by its shape and a prior.
The prior is copied to each index meaning that a 2x2 array has four independent priors.
Arrays can be accessed and modified using indexing.
They can be instantiated just as any other model class