Utils¶
Utility Functions¶
-
pysigma.utils.compatible_shape(msg_shape1, msg_shape2)[source]¶ Checks whether the two given message shapes are compatible.
Both msg_shape1 and msg_shape2 should be iterables of torch.Size with the contents
(batch_shape, param_shape, sample_shape, event_shape). An empty shape, i.e., torch.Size([]), is deemed compatible with any other shape. msg_shape1 is compatible with msg_shape2 if and only if each of its four entries is compatible with its counterpart in msg_shape2.
- Parameters
msg_shape1 (tuple of torch.Size) – First shape. Should have the format
(batch_shape, param_shape, sample_shape, event_shape).
msg_shape2 (tuple of torch.Size) – Second shape. Same format as msg_shape1.
- Returns
True if the two shapes are compatible.
- Return type
bool
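The rule above can be illustrated with a minimal sketch. The description does not say when two non-empty shapes are compatible, so this sketch assumes they must be equal. The name compatible_shape_sketch is hypothetical and is not part of pysigma.

```python
import torch


def compatible_shape_sketch(msg_shape1, msg_shape2):
    """Illustrative re-implementation; not the pysigma source."""
    # Each message shape holds (batch_shape, param_shape, sample_shape, event_shape).
    assert len(msg_shape1) == 4 and len(msg_shape2) == 4
    for s1, s2 in zip(msg_shape1, msg_shape2):
        # An empty shape is compatible with any other shape.
        if len(s1) == 0 or len(s2) == 0:
            continue
        # Assumption: two non-empty shapes are compatible only if they are equal.
        if s1 != s2:
            return False
    return True


empty = torch.Size([])
shape_a = (torch.Size([3]), empty, torch.Size([10]), torch.Size([2]))
shape_b = (torch.Size([3]), torch.Size([5]), torch.Size([10]), torch.Size([2]))
shape_c = (torch.Size([4]), empty, torch.Size([10]), torch.Size([2]))
```

Here shape_a and shape_b are compatible because the only differing entry is empty in shape_a, while shape_a and shape_c differ in a non-empty batch shape.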
DistributionServer¶
-
class
pysigma.utils.DistributionServer[source]¶ Serving distribution class dependent utilities
Conversion between PyTorch distribution parameters and distribution instance:
param2dist(), dist2param()
Translation between PyTorch distribution parameters and natural parameters for exponential family distribution:
natural2exp_dist(), exp_dist2natural()
Get vector of moments from a given distribution instance:
get_moments()
Draw particles from distribution instance:
draw_particles()
Get log probability density from given particles:
log_pdf()
Certain distribution classes require special handling. For example, for distributions categorized as finite discrete, particle values are drawn uniformly so that every value in the RV’s value range (support) is covered once and only once, and each particle is assigned its probability mass as its particle weight.
We therefore delegate all such special handling to this class on an individual basis.
Note that inputs and outputs conform to the format understood by PyTorch’s distribution classes. To translate to and from formats compatible with PySigma’s predicate knowledge, use the KnowledgeServer class.
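As an illustration of the finite discrete case described above, the following sketch enumerates the support of an unbatched Categorical distribution once and uses the probability mass of each value as its weight. The helper name enumerate_discrete_particles is hypothetical, and the handling of batched distributions in pysigma may differ.

```python
import torch
from torch.distributions import Categorical


def enumerate_discrete_particles(dist):
    """Return every support value exactly once, with its probability mass as weight."""
    # For an unbatched Categorical, enumerate_support() has shape [num_values].
    values = dist.enumerate_support()
    # The weight of each particle is its probability mass.
    weights = dist.log_prob(values).exp()
    return values, weights


dist = Categorical(probs=torch.tensor([0.2, 0.5, 0.3]))
values, weights = enumerate_discrete_particles(dist)
```

Because each value appears exactly once, the weights sum to one over the particle list.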
-
classmethod
param2dist(dist_class, param, b_shape=None, e_shape=None, dist_info=None)[source]¶ Converts distribution parameter to a distribution instance.
Depending on the context and the Predicate knowledge format, the parameter param may belong to different representation systems, in which case it should be interpreted differently. Such a specification should be fully described in the argument dist_info, in a previously agreed-upon format.
The optional arguments b_shape and e_shape stand for the distribution’s batch shape and event shape respectively. They are used primarily for sanity checks. Note that the event shape e_shape pertains to PyTorch’s distribution class specification, and therefore may differ from the event shape of particles in PySigma’s cognitive format.
- Parameters
dist_class (type) – The distribution class. Must be a subclass of torch.distributions.Distribution.
param (torch.Tensor) – The parameter tensor. The last dimension is assumed to be the parameter dimension, and sizes of the dimensions at the front should be equal to b_shape.
b_shape (torch.Size, optional) – The batch shape of the distribution. Used for shape sanity check.
e_shape (torch.Size, optional) – The presumed event shape of the distribution. Used for shape sanity check.
dist_info (dict, optional) – An optional dict containing all relevant information in order to correctly interpret the parameter param.
- Returns
The instantiated distribution instance.
- Return type
torch.distributions.Distribution
- Raises
NotImplementedError – If the conversion procedure specific to dist_class has not been implemented yet.
ValueError – If the batch shape or event shape of the converted distribution instance differs from the specified b_shape or e_shape respectively.
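A minimal sketch of the dispatch-and-check pattern described above, assuming a registry keyed by distribution class. The parameter layout for Normal (last dimension holding [loc, scale]) and the names _PARAM2DIST_SKETCH and param2dist_sketch are assumptions made for illustration, not the pysigma implementation.

```python
import torch
from torch.distributions import Normal


def _normal_param2dist(param, dist_info=None):
    # Assumed layout: the last dimension of param holds [loc, scale].
    return Normal(param[..., 0], param[..., 1])


# Hypothetical method map from distribution class to conversion function.
_PARAM2DIST_SKETCH = {Normal: _normal_param2dist}


def param2dist_sketch(dist_class, param, b_shape=None, e_shape=None, dist_info=None):
    if dist_class not in _PARAM2DIST_SKETCH:
        raise NotImplementedError(
            "No conversion registered for {}".format(dist_class.__name__))
    dist = _PARAM2DIST_SKETCH[dist_class](param, dist_info)
    # Shape sanity checks, performed only when the expected shapes are given.
    if b_shape is not None and dist.batch_shape != torch.Size(b_shape):
        raise ValueError("Unexpected batch shape {}".format(dist.batch_shape))
    if e_shape is not None and dist.event_shape != torch.Size(e_shape):
        raise ValueError("Unexpected event shape {}".format(dist.event_shape))
    return dist


param = torch.tensor([[0.0, 1.0], [2.0, 3.0]])  # shape (b_shape + [param_size])
dist = param2dist_sketch(Normal, param, b_shape=torch.Size([2]), e_shape=torch.Size([]))
```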
-
classmethod
dist2param(dist, dist_info=None)[source]¶ Extract the parameter tensor from a given distribution instance.
Depending on the context and the Predicate knowledge format, the desired parameter may belong to different representation systems, in which case it should be generated differently. Such a specification should be fully described in the argument dist_info, in a previously agreed-upon format.
- Parameters
dist (torch.distributions.Distribution) – The distribution instance from which to extract the parameter.
dist_info (dict, optional) – An optional dict containing all relevant information in order to correctly generate the parameter param.
- Returns
The parameter tensor in the desired format.
- Return type
torch.Tensor
- Raises
NotImplementedError – If the conversion procedure specific to the distribution class of dist has not been implemented yet.
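For illustration, the reverse direction might look like the sketch below for a Normal distribution. The stacked [loc, scale] layout and the function name normal_dist2param_sketch are assumptions. The actual parameter representation depends on dist_info.

```python
import torch
from torch.distributions import Normal


def normal_dist2param_sketch(dist, dist_info=None):
    if not isinstance(dist, Normal):
        raise NotImplementedError(
            "No conversion registered for {}".format(type(dist).__name__))
    # Assumed layout: stack loc and scale along a new last (parameter) dimension.
    return torch.stack([dist.loc, dist.scale], dim=-1)


dist = Normal(torch.tensor([0.0, 2.0]), torch.tensor([1.0, 3.0]))
param = normal_dist2param_sketch(dist)  # shape (batch_shape + [2])
```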
-
classmethod
get_moments(dist, n_moments)[source]¶ Get vector of moments from a given distribution instance
Todo
Implement with dist_info
-
classmethod
draw_particles(dist, num_particles, dist_info=None)[source]¶ Draw a list of num_particles event particles from the given distribution specified by dist. The event particles drawn will be in the format compatible with DistributionServer and PyTorch.
- Parameters
dist (torch.distributions.Distribution) – The distribution instance from which to sample particles
num_particles (int) – The number of particles to be drawn
dist_info (dict, optional) – Additional dist info necessary for drawing particles in the correct format.
- Returns
The list of particles drawn, of shape [num_particles, event_size].
- Return type
torch.Tensor
- Raises
NotImplementedError – If the given distribution yields multi-dimensional events, and no corresponding special drawing method is found in the cls.dict_draw_particles method map.
Notes
Unless a distribution-class-specific drawing method is specified and registered in the cls.dict_draw_particles method map, the distribution instance dist will be queried directly to draw the list of samples.
The distribution instance dist is assumed batched, with a variable batch size (shape). However, we want to draw a single list of particles that is representative of every distribution in the batch, i.e., draw a list of particles from the joint distribution regardless of the batch dimensions. We therefore take the view that drawing samples from dist simultaneously across the batch (which results in a sample tensor that involves the batch dimensions) and then ignoring the batch dimensions is equivalent to repeatedly selecting one distribution from the batch uniformly at random and drawing a sample from it. When the samples are aggregated, the latter approach yields a particle list that is representative of the joint distribution of the whole batch.
Accordingly, the sampling process is implemented by drawing n batched samples from dist, where n = num_particles // batch_size + 1, collapsing the batch dimensions, randomly shuffling across the collapsed sample dimension, and truncating to keep only num_particles samples.
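The default procedure in the notes above can be sketched as follows. This is an illustrative version, not the pysigma source: the function name is hypothetical, and scalar events (empty event_shape) are assumed to be treated as events of size 1.

```python
import torch
from torch.distributions import Normal


def draw_particles_sketch(dist, num_particles):
    if len(dist.event_shape) > 1:
        raise NotImplementedError("Multi-dimensional events need a special method")
    batch_size = dist.batch_shape.numel()
    # Assumption: an empty event shape is treated as event_size == 1.
    event_size = dist.event_shape.numel()
    # Draw enough batched samples to cover num_particles after collapsing.
    n = num_particles // batch_size + 1
    samples = dist.sample((n,))  # shape [n] + batch_shape + event_shape
    # Collapse the sample and batch dimensions into one particle dimension.
    flat = samples.reshape(n * batch_size, event_size)
    # Shuffle so that truncation does not favor any entry of the batch.
    perm = torch.randperm(flat.shape[0])
    return flat[perm][:num_particles]


dist = Normal(torch.zeros(3, 4), torch.ones(3, 4))  # batch_shape [3, 4]
particles = draw_particles_sketch(dist, 10)          # shape [10, 1]
```

With a batch of 12 distributions and num_particles = 10, n evaluates to 1, so 12 samples are drawn, shuffled, and truncated to 10.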
-
classmethod
log_prob(dist, values)[source]¶ Get the log probability mass/density of the given particle values w.r.t. the given batched distribution instance.
The particle values should be in a format compatible with PyTorch’s distribution classes. This means the last dimension of values is assumed to be the event dimension, and should be compatible with, if not identical to, dist.event_shape. All preceding dimensions are assumed to be sample dimensions, whose sizes together form the sample_shape.
The distribution instance dist is assumed batched. In other words, its batch shape dist.batch_shape should not be empty.
- Parameters
dist (torch.distributions.Distribution) – A batched distribution instance. Its batch shape dist.batch_shape should not be empty.
values (torch.Tensor) – A tensor with shape
(sample_shape + [event_size])
- Returns
The log probability mass/density tensor, of shape
(dist.batch_shape + sample_shape)
- Return type
torch.Tensor
- Raises
AssertionError – If the event_size found in values is different from dist.event_shape.
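One way to obtain the documented output shape (dist.batch_shape + sample_shape) is to insert singleton batch dimensions into values, rely on broadcasting, and then move the batch dimensions to the front. The sketch below assumes a one-dimensional event shape. The function name is hypothetical and the pysigma implementation may differ.

```python
import torch
from torch.distributions import Independent, Normal


def log_prob_sketch(dist, values):
    assert len(dist.batch_shape) > 0, "dist must be batched"
    # Assumption: events are one-dimensional vectors.
    assert len(dist.event_shape) == 1
    assert values.shape[-1] == dist.event_shape[0]
    sample_shape = tuple(values.shape[:-1])
    n_s, n_b = len(sample_shape), len(dist.batch_shape)
    # Insert singleton batch dims: sample_shape + (1,) * n_b + (event_size,).
    expanded = values.reshape(sample_shape + (1,) * n_b + (values.shape[-1],))
    log_p = dist.log_prob(expanded)  # shape sample_shape + batch_shape
    # Move the batch dimensions to the front: batch_shape + sample_shape.
    return log_p.permute(*range(n_s, n_s + n_b), *range(n_s))


dist = Independent(Normal(torch.zeros(2, 3, 4), torch.ones(2, 3, 4)), 1)
values = torch.randn(5, 4)                 # sample_shape [5], event_size 4
log_p = log_prob_sketch(dist, values)      # shape [2, 3, 5]
```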
-
classmethod
kl_norm(dist1, dist2)[source]¶ Get the norm of the KL divergence of two given batched distributions
-
classmethod
transform_param(param, dist_info, trans)[source]¶ Todo
To implement
Return the parameter of the transformed distribution
-
dict_draw_particles= {<class 'torch.distributions.categorical.Categorical'>: <classmethod object>}¶
-
dict_log_pdf= {}¶
-
dict_param2dist= {}¶
-
dict_dist2param= {}¶
-
dict_natural2exp_param= {}¶
-
dict_exp_param2natural= {}¶
-
dict_natural2exp_dist= {}¶
-
dict_exp_dist2natural= {}¶
-
dict_get_moments= {}¶
KnowledgeServer¶
-
class
pysigma.utils.KnowledgeServer(dist_class, rv_sizes, rv_constraints, rv_num_particles, dist_info=None)[source]¶ Knowledge Server class. Provides service regarding a Predicate’s knowledge.
The architecture should hold one KnowledgeServer instance for each Predicate instantiated to cache knowledge contents and provide distribution related service.
- Parameters
dist_class (type) – The distribution class of the Predicate’s knowledge. Must be a subclass of torch.distributions.Distribution.
rv_sizes (iterable of int) – The sizes of the random variables of the Predicate’s knowledge. Note that the order given by the iterable will be respected.
rv_constraints (iterable of torch.distributions.constraints.Constraint) – The value constraints of the random variables. Note that the order given by the iterable will be respected.
rv_num_particles (iterable of int) – The number of marginal particles that should be drawn w.r.t. each random variable. Must have the same length as rv_sizes and rv_constraints, i.e., the number of random variables.
dist_info (dict, optional) – An optional attribute dict that contains all necessary information for DistributionServer to draw particles and query particles’ log pdf.
-
dist_class¶ - Type
type
-
rv_sizes¶ - Type
tuple of int
-
rv_constraints¶ - Type
tuple of torch.distributions.constraints.Constraint
-
rv_num_particles¶ - Type
tuple of int
-
dist_info¶ - Type
dict
-
num_rvs¶ Number of random variables involved in specifying the Predicate knowledge.
- Type
int
-
e_shape¶ The event shape of the predicate’s knowledge. Inferred from rv_sizes.
- Type
torch.Size
-
particles¶ The cached tuple of marginal particle event tensors corresponding to the random variables. This attribute is set when draw_grid_particles is called with update_cache=True.
- Type
tuple of torch.Tensor
-
log_densities¶ The cached tuple of log sampling density tensors corresponding to each of the marginal particle events. This attribute is set when draw_grid_particles is called with update_cache=True.
- Type
tuple of torch.Tensor
Notes
In order to serve both predicate nodes and conditional nodes in all stages, KnowledgeServer should store and manipulate data regarding the random variables only. In other words, only message components that do not involve batch dimensions should be cached. This includes particle value tensors and log sampling density tensors, but excludes both parameter and weight tensors. The shapes of the latter are not invariant throughout the stages in the conditional subgraph, and should therefore be specified by the caller.
Signatures for special private distribution class dependent methods:
Cognitive to PyTorch event format translation method:
2torch_event(particles) --> particles
PyTorch to Cognitive event format translation method:
2cognitive_event(particles) --> particles
Special marginal particle list sampling method:
special_draw(batched_dist) --> particles, log_densities
-
draw_particles(batched_param, batch_shape, update_cache=True)[source]¶ Draws new particles for the associated predicate w.r.t. the given batched_param. Returns necessary components to instantiate a particles message.
This method is typically called by the predicate’s LTMFN node during the modification phase, in which the updated batched parameter tensor has been obtained and is provided by batched_param. This method then proceeds to:
1. instantiate the batched distribution instances from the batched parameter tensor,
2. draw a single unique list of marginal particle values w.r.t. each random variable from the entire batch of distribution instances,
3. calculate their corresponding marginal sampling densities.
- Parameters
batched_param (torch.Tensor) – The new batched parameter tensor, of shape (batch_shape + [param_size]).
batch_shape (torch.Size) – The batch shape.
update_cache (bool) – Whether to replace the cache content in self.particles and self.log_densities with the result of calling this method.
- Returns
particles (tuple of torch.Tensor) – The marginal particle lists w.r.t. each random variable in order.
log_densities (tuple of torch.Tensor) – The marginal sampling log densities w.r.t. each random variable in order.
Notes
Some remarks regarding the aforementioned steps 2 and 3:
The tuple of the types of the rv constraints specified in self.rv_constraints will be used to look up the pre-specified method map self.dict_2special_draw. If an entry is present, that method will be used to obtain the returned particles and log_densities. This is particularly useful, for instance, in the case of finite discrete random variables where a regular lattice should be drawn uniformly.
Otherwise, the standard procedure will be carried out.
-
surrogate_log_prob(param, alt_particles=None, index_map=None)[source]¶ Query the log pdf of the surrogate particles specified by alt_particles w.r.t. the cached distribution instance.
A batched distribution instance will be instantiated from param, along with registered metadata in self.dist_info.
If index_map is not specified, each entry in the iterable alt_particles must represent events of the Predicate’s random argument at the same index in the predicate argument list. If an entry is None, the cached particle tensor of that predicate argument will be used instead.
Alternatively, one can specify a dictionary index_map mapping an integer index to an integer index or a list of indices. The entry alt_particles[i] will be taken as the particle tensor for the index_map[i]th predicate argument. If index_map[i] is a list of integers, then the particle tensor at this position will be interpreted as the concatenated/joint events of those predicate arguments whose indices are in index_map[i]. Note that the entry alt_particles[i] can be None; in this case, however, index_map[i] must refer to one predicate argument only. If any predicate argument is not referenced by the values of index_map, then the returned surrogate log prob will be marginalized over this predicate argument.
Correspondingly, if index_map is specified, then all indices in alt_particles must appear as keys.
- Parameters
param (torch.Tensor, optional) – The alternative parameter from which a surrogate distribution instance is to be instantiated and its log prob queried. Should have the same shape as the cached self.batched_param.
alt_particles (list of (torch.Tensor or None), or None) – The surrogate particles to be queried. If not None, each entry must be either None, in which case the corresponding cached particles will be used instead, or a torch.Tensor with a shape of length 2 whose last dimension size equals the corresponding value in self.rv_sizes. Defaults to None.
index_map (dict, or None) – The optional index mapping. If specified, all applicable indices into alt_particles must appear as keys in this dict. The ith entry in alt_particles will be taken as the surrogate particles for the predicate argument whose index is index_map[i] if index_map[i] is an integer, or the joint surrogate particles for those arguments whose indices appear in index_map[i] if index_map[i] is a list.
- Returns
The log probability tensor, of shape (batch_shape + sample_shape), where batch_shape is the batch shape of the Predicate’s knowledge, inferred from the shape of param, and sample_shape is the list of sample sizes of the queried particles in the order given by alt_particles. In other words, sample_shape[i] == alt_particles[i].shape[0]. If alt_particles[i] is None, it is the sample size of the cached particle tensor of the corresponding predicate argument.
- Return type
torch.Tensor
- Raises
AssertionError – If self.batched_param is None, meaning there are no cached parameters from which to instantiate a distribution instance.
AssertionError – If alt_particles contains None but self.particles is also None, meaning there are no cached particles.
AssertionError – If alt_param is specified but has a different shape than self.batched_param.
-
event2torch_event(cat_particles)[source]¶ Translates joint particle event values from the Cognitive format to a format understandable by PyTorch distribution class.
- Parameters
cat_particles (torch.Tensor) – A tensor representing the list of concatenated particle events in Cognitive format. Its last dimension will be taken as the event dimension and should be equal to the sum of rv sizes in self.rv_sizes, while all other dimensions will be taken as the sample dimensions.
- Returns
A tensor representing a list of translated particle events from cat_particles. Its last dimension size depends on the PyTorch format representation of events, while the sizes of other dimensions are the same as cat_particles.
- Return type
torch.Tensor
Notes
The specific translation method may vary depending on the distribution class. Therefore, this method serves only as an API entry point: the specific translation procedure is looked up in self.dict_2torch_event using the registered distribution class self.dist_class. If no entry is found, no special translation is assumed necessary and the input cat_particles is returned as is.
-
event2cognitive_event(particles)[source]¶ Translates joint particle event values from the PyTorch distribution class format to Cognitive format.
- Parameters
particles (torch.Tensor) – A tensor representing the particle events in PyTorch-compatible format. Its last dimension will be taken as the event dimension, while all other dimensions will be taken as the sample dimensions.
- Returns
A concatenated tensor representing a list of translated particle events from particles, where the events are concatenated along the last dimension, with the size of each chunk in accordance with self.rv_sizes. The sizes of all other dimensions are the same as those of particles.
- Return type
torch.Tensor
Notes
The specific translation method may vary depending on the distribution class. Therefore, this method serves only as an API entry point: the specific translation procedure is looked up in self.dict_2cognitive_event using the registered distribution class self.dist_class. If no entry is found, no special translation is assumed necessary and the input particles is returned as is.
-
static
combinatorial_cat(particles)[source]¶ Helper static method that combinatorially concatenates the list of event particles specified by particles.
Returns the contained tensor directly if there is only one entry in particles.
- Parameters
particles (iterable of torch.Tensor) – The list of particles to be concatenated. Each element should be a tensor with a shape of length 2, where the first dimension is assumed the sample dimension, and second dimension assumed the event dimension.
- Returns
The combinatorially concatenated event particle tensor of shape:
[sample_size[0], ..., sample_size[m], event_size[0]+...+event_size[m]]
where sample_size[i] is the sample size of the ith particle tensor, and similarly for event_size[i]. Its total number of dimensions, i.e., .dim(), is equal to the number of random variables plus 1.
- Return type
torch.Tensor
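A minimal sketch of the combinatorial concatenation described above: each particle tensor is broadcast across the sample dimensions of all the others, and the results are concatenated along the event dimension. The function name is hypothetical and the pysigma implementation may differ in detail.

```python
import torch


def combinatorial_cat_sketch(particles):
    particles = list(particles)
    # A single entry is returned directly.
    if len(particles) == 1:
        return particles[0]
    m = len(particles)
    sample_sizes = [p.shape[0] for p in particles]
    expanded = []
    for i, p in enumerate(particles):
        # Place the i-th sample dimension at position i; use size 1 elsewhere.
        view_shape = [1] * m + [p.shape[1]]
        view_shape[i] = p.shape[0]
        # Broadcast over the sample dimensions of all other variables.
        expanded.append(p.reshape(view_shape).expand(*sample_sizes, p.shape[1]))
    # Concatenate along the event (last) dimension.
    return torch.cat(expanded, dim=-1)


a = torch.arange(6.0).reshape(3, 2)   # 3 particles, event size 2
b = torch.arange(4.0).reshape(4, 1)   # 4 particles, event size 1
joint = combinatorial_cat_sketch([a, b])  # shape [3, 4, 3]
```

Entry joint[i, j] is the concatenation of a[i] and b[j], so the result enumerates every combination of the marginal particles.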
-
static
combinatorial_decat(cat_particles, split_sizes)[source]¶ Helper static method that combinatorially de-concatenates the joint particles specified by cat_particles, with the event size of the particle tensors in each de-concatenated list given by split_sizes. This method implements the exact inverse of combinatorial_cat.
An exception will be raised if cat_particles cannot be properly de-concatenated, for instance if it is not previously a result produced by combinatorial_cat.
- Parameters
cat_particles (torch.Tensor) –
The concatenated particle tensor, of shape:
[sample_size[0], ..., sample_size[m], sum(split_sizes)]
where m is the number of variables to split/de-concatenate, to which the length of split_sizes should equal.
split_sizes (list of int) – A list of integers denoting the event size of each split variable in order.
- Returns
The tuple of de-concatenated particles. The ith entry has shape [sample_size[i], split_sizes[i]].
- Return type
tuple of torch.Tensor
- Raises
ValueError – If cat_particles was not a result from combinatorial_cat and cannot be properly de-concatenated.
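The inverse operation can be sketched as follows: split along the event dimension, read off each marginal from a single slice, and verify that re-broadcasting it reproduces the chunk. The function name and the exact validity check are assumptions made for illustration.

```python
import torch


def combinatorial_decat_sketch(cat_particles, split_sizes):
    split_sizes = list(split_sizes)
    m = len(split_sizes)
    if cat_particles.dim() != m + 1 or cat_particles.shape[-1] != sum(split_sizes):
        raise ValueError("cat_particles has an incompatible shape")
    chunks = torch.split(cat_particles, split_sizes, dim=-1)
    result = []
    for i, chunk in enumerate(chunks):
        # Keep sample dimension i and the event dimension; take index 0 elsewhere.
        index = [0] * m + [slice(None)]
        index[i] = slice(None)
        marginal = chunk[tuple(index)]  # shape [sample_size[i], split_sizes[i]]
        # A valid chunk must be constant along every other sample dimension.
        view_shape = [1] * m + [split_sizes[i]]
        view_shape[i] = marginal.shape[0]
        if not torch.equal(marginal.reshape(view_shape).expand_as(chunk), chunk):
            raise ValueError("cat_particles is not a combinatorial concatenation")
        result.append(marginal)
    return tuple(result)


a = torch.arange(6.0).reshape(3, 2)
b = torch.arange(4.0).reshape(4, 1)
# Build a joint tensor by hand: entry [i, j] is the concatenation of a[i] and b[j].
joint = torch.cat([a[:, None, :].expand(3, 4, 2), b[None, :, :].expand(3, 4, 1)], dim=-1)
parts = combinatorial_decat_sketch(joint, [2, 1])
```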