Normalizing Flows on categorical data
A brief introduction to Categorical Normalizing Flows and GraphCNF
Ever since Generative Adversarial Networks were shown to generate realistic fake images and GPT-3 to produce new, convincing text snippets, generative modeling has been a very popular research area with promising future applications. One family of generative models that allows both efficient sampling (i.e. generating new data) and exact density evaluation (i.e. determining the likelihood of a sample) is Normalizing Flows. Normalizing Flows model distributions by applying a sequence of invertible transformations that map the input distribution to a known base distribution, such as a factorized Gaussian. Successful applications include image and audio generation. However, a lot of real-world data is categorical in nature, meaning that each variable takes a discrete value from a finite set. For instance, language is based on categorical data, as words or characters represent different categories. Another popular example is molecule graphs, where not only the nodes can be one out of many atom types, but the edges (i.e. bonds between atoms) can also have different categories. In this blog post, we will discuss what challenges arise when applying Normalizing Flows to categorical data, how Categorical Normalizing Flows offer a solution to them, and how we can extend this framework to a permutation-invariant graph generation model.
Normalizing Flows on categorical data
Normalizing Flows rely on the change of variables rule, which is naturally defined on continuous space. Attempts to apply Normalizing Flows directly to discrete data by discretizing their transformations suffer from optimization difficulties, as biased gradient estimators need to be used. Besides, Discrete Normalizing Flows are not as flexible as their continuous counterparts. In the context of image modeling, a popular technique is to move the images into continuous space by dequantizing the discrete data, effectively adding a small amount of noise to transform a discrete point into a volume in continuous space. However, this only works for images because their discrete values represent integers with a natural order, where 0 is closer to 1 than to 128, and so on. For categorical data, we cannot make such assumptions (for instance, is the letter "a" closer to "m" or to "r"?). Instead, we need a learnable encoding function from categorical to continuous space, which supports higher dimensions to represent more complex relations between categories. At the same time, no information should be lost when moving into continuous space, as we want to keep all the benefits of Normalizing Flows.
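To make the two ingredients of this paragraph concrete, here is a minimal NumPy sketch of uniform dequantization for integer data and of the change of variables rule for a one-dimensional affine flow. The function names and the choice of an affine transformation with a standard normal base distribution are assumptions made purely for illustration; real flows stack many learned invertible layers.

```python
import numpy as np

def dequantize(x_int, rng):
    """Uniform dequantization: the integer k becomes a point in [k, k + 1)."""
    return x_int + rng.uniform(0.0, 1.0, size=x_int.shape)

def quantize(x_cont):
    """Inverse direction: every point in [k, k + 1) maps back to k."""
    return np.floor(x_cont).astype(np.int64)

def affine_flow_log_prob(x, scale, shift):
    """Exact log-density of x under z = scale * x + shift, z ~ N(0, 1).

    Change of variables: log p_x(x) = log p_z(f(x)) + log |df/dx|.
    """
    z = scale * x + shift
    log_base = -0.5 * (z ** 2 + np.log(2.0 * np.pi))
    log_det = np.log(np.abs(scale))
    return log_base + log_det

rng = np.random.default_rng(0)
pixels = np.array([0, 1, 128, 255])
noisy = dequantize(pixels, rng)   # continuous values, one volume per integer
recovered = quantize(noisy)       # the discrete values are not lost
```

Dequantization works here because neighboring integers are also neighbors in continuous space. For categories without such an order, the same trick would impose an arbitrary notion of closeness, which is the motivation for a learned encoding.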
Categorical Normalizing Flows offer a solution to this problem by using variational inference, but with a factorized decoder. In a VAE, an encoder maps the whole input to a smaller latent space, and a subsequent decoder tries to reconstruct the original input. In contrast, in Categorical Normalizing Flows, the encoder maps each categorical variable to a corresponding continuous variable in latent space, and the decoder is applied independently to each of those latent variables. A standard Normalizing Flow can then be applied to those continuous variables to map them to a prior distribution like a Gaussian. This way, only the flow can model interactions between categorical variables, and the encoder-decoder framework is limited to providing suitable continuous representations to the flow. An overview of this architecture is given in the figure below, and details on this approach can be found in our paper.
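The following NumPy sketch illustrates what "factorized" means in practice. It assumes, for simplicity, that each category is encoded by its own diagonal Gaussian in latent space and that the decoder is obtained from the encoder via Bayes' rule; the Gaussian choice, the fixed random parameters, and all names are assumptions of this sketch and not necessarily what the paper uses. In a trained model, the means and scales would be learned jointly with the flow.

```python
import numpy as np

def log_normal(z, mean, log_std):
    """Log-density of a diagonal Gaussian, summed over the last axis."""
    var_term = ((z - mean) / np.exp(log_std)) ** 2
    return np.sum(-0.5 * var_term - log_std - 0.5 * np.log(2.0 * np.pi), axis=-1)

def encode(x, means, log_stds, rng):
    """Sample z_i ~ q(z_i | x_i) independently for every position i."""
    eps = rng.standard_normal((x.shape[0], means.shape[1]))
    return means[x] + np.exp(log_stds[x]) * eps

def decode(z, means, log_stds, log_prior):
    """Factorized decoder: p(x_i | z_i) looks at z_i only."""
    # log q(z_i | x_i = c) for every category c: shape (positions, categories)
    log_q = log_normal(z[:, None, :], means[None], log_stds[None])
    logits = log_q + log_prior[None]
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    probs = np.exp(logits)
    return probs / probs.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
num_categories, latent_dim = 4, 3
means = 3.0 * rng.standard_normal((num_categories, latent_dim))
log_stds = np.full((num_categories, latent_dim), -1.0)
log_prior = np.log(np.full(num_categories, 1.0 / num_categories))

x = np.array([0, 2, 1, 3, 2])                  # a sequence of 5 categorical variables
z = encode(x, means, log_stds, rng)            # shape (5, latent_dim)
probs = decode(z, means, log_stds, log_prior)  # shape (5, num_categories)
```

Because `decode` never mixes information across positions, any dependency between the categorical variables has to be captured by the flow that operates on `z`.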
In experiments on sets and language modeling, Categorical Normalizing Flows are able to match the performance of autoregressive baselines without latent variables while considerably outperforming VAE-based approaches with joint decoders. This shows the effectiveness of using a factorized decoder and concentrating the modeling complexity in the flow.
GraphCNF: permutation-invariant graph modeling
Besides language, graph generation is another domain where categorical data plays an essential role. To make it more specific, we put our focus here on molecule generation, but note that the approach outlined below can also be applied to other graph generation tasks with categorical attributes, such as graph coloring.
Modeling and generating graphs is crucial in biology and chemistry for applications such as drug discovery, where molecule generation has emerged as a common benchmark. In a molecule graph, the nodes are atoms and the edges represent bonds between atoms, both represented by categorical features. The figure below shows a small molecule and its corresponding graph representation as an example.
Given a dataset of existing molecules, the goal is to learn a distribution of valid molecules, as not all possible combinations of atoms and bonds are valid. We train the model via maximum likelihood on the existing data. To do so, we need to model three parts of the graph: the node attributes (i.e. atom types), the edge attributes (i.e. bond category), and the adjacency matrix (i.e. whether two atoms have a bond at all or not). Another challenge is that the nodes are unordered, meaning they have to be treated as a set and not as a list. Hence, every permutation of the nodes has to be assigned exactly the same likelihood.
Using Categorical Normalizing Flows, we can implement this in a three-step approach as visualized in Figure 4, and call the approach GraphCNF. First, we map the node categories into continuous space using the approach of Categorical Normalizing Flows discussed above. On those latent variables, we apply a small sequence of coupling layers conditioned on the rest of the graph structure. In the second step, we map the edge attributes into continuous space, and apply another set of coupling layers. Finally, in the last step, we add the virtual edges, i.e. the edges that are not present in our graph. We add those as another edge category using Categorical Normalizing Flows, and apply a final set of coupling layers. Note that as those are applied on a fully connected graph now, this is the most expensive step of GraphCNF.
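The role of the virtual edges in the third step can be illustrated with a small NumPy sketch. It assumes a hypothetical convention in which category 0 denotes "no bond" and real bond categories start at 1; the helper name and the toy three-atom graph are made up for illustration.

```python
import numpy as np

NO_BOND = 0  # hypothetical index of the virtual-edge category

def dense_edge_categories(num_nodes, bonds):
    """Build a symmetric (N, N) matrix of edge categories.

    bonds maps node pairs (i, j) to a bond category >= 1. Every pair that is
    not listed becomes a virtual edge with category NO_BOND.
    """
    edges = np.full((num_nodes, num_nodes), NO_BOND, dtype=np.int64)
    for (i, j), category in bonds.items():
        edges[i, j] = category
        edges[j, i] = category
    return edges

# Toy graph with three atoms in a chain: 0 - 1 - 2 (both single bonds).
edges = dense_edge_categories(3, {(0, 1): 1, (1, 2): 1})

# Each unordered node pair appears once in the upper triangle.
upper = edges[np.triu_indices(3, k=1)]
num_real = int((upper != NO_BOND).sum())
num_virtual = int((upper == NO_BOND).sum())
```

With N nodes there are N(N-1)/2 node pairs in total, while molecules typically have only a small number of bonds per atom. This is why the step that operates on the fully connected graph dominates the cost.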
To ensure permutation invariance, all coupling layers use channel-wise masking, and all neural networks in the model are permutation-equivariant graph neural networks (i.e. permuting the input nodes gives the same output, permuted in the same way). When applying this model to molecule generation and randomly sampling 10k molecules from it, we see that more than 83% are valid molecules. Furthermore, most of them are unique and novel compared to the training dataset, showing that GraphCNF actually generalizes and does not just memorize the training examples.
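As a rough illustration of why channel-wise masking combined with an equivariant network yields a permutation-invariant likelihood, the NumPy sketch below builds a toy affine coupling layer. The single message-passing layer, the weight shapes, and all names are assumptions of this sketch and far simpler than the networks used in GraphCNF. The mask splits the channels of every node, never the nodes themselves, so no node ordering enters the computation.

```python
import numpy as np

def gnn_layer(h, adj, w_self, w_nbr):
    """Toy message passing: each node combines itself with the sum of its neighbors."""
    return np.tanh(h @ w_self + adj @ h @ w_nbr)

def coupling_forward(z, adj, params):
    """Affine coupling with a channel-wise mask.

    The first half of the channels of every node stays fixed and conditions
    the transformation of the second half.
    """
    half = z.shape[1] // 2
    z_a, z_b = z[:, :half], z[:, half:]
    hidden = gnn_layer(z_a, adj, params["w_self"], params["w_nbr"])
    log_scale = np.tanh(hidden @ params["w_s"])
    shift = hidden @ params["w_t"]
    out = np.concatenate([z_a, z_b * np.exp(log_scale) + shift], axis=1)
    log_det = log_scale.sum()  # a sum over nodes does not depend on their order
    return out, log_det

rng = np.random.default_rng(0)
n, d, hidden_dim = 5, 4, 8
z = rng.standard_normal((n, d))
adj = np.triu(rng.integers(0, 2, size=(n, n)), k=1)
adj = adj + adj.T  # symmetric adjacency matrix without self-loops
params = {
    "w_self": rng.standard_normal((d // 2, hidden_dim)),
    "w_nbr": rng.standard_normal((d // 2, hidden_dim)),
    "w_s": rng.standard_normal((hidden_dim, d // 2)),
    "w_t": rng.standard_normal((hidden_dim, d // 2)),
}

perm = rng.permutation(n)
out, log_det = coupling_forward(z, adj, params)
out_perm, log_det_perm = coupling_forward(z[perm], adj[perm][:, perm], params)
```

Permuting the nodes permutes the rows of the output in the same way (`out_perm` equals `out[perm]`) and leaves the log-determinant unchanged, which is the property the likelihood needs.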
The most common failure case we observed was that the sampled graph was not connected, meaning that we actually generated two or more separate graphs. This can happen because, in contrast to autoregressive models, it is much harder for parallel models to ensure that every node can be reached from any other node. When considering only the largest connected subgraph in those samples, we are able to push the validity percentage to 96% without enforcing any explicit rule of molecule graphs. This is a relatively large gap to the closest state-of-the-art autoregressive flow model at the time of writing, GraphAF, and shows the importance of the encoding in Normalizing Flows, and of permutation invariance for graph generation in general.
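One plausible way to extract the largest connected subgraph from a sampled adjacency matrix is a breadth-first search over connected components, sketched below. This is an illustrative post-processing step under the assumption that "biggest subgraph" refers to the largest connected component; the function name is hypothetical and the exact procedure used in the experiments may differ.

```python
from collections import deque

import numpy as np

def largest_connected_component(adj):
    """Return the sorted node indices of the largest connected component."""
    n = adj.shape[0]
    seen = np.zeros(n, dtype=bool)
    best = []
    for start in range(n):
        if seen[start]:
            continue
        # Breadth-first search collects every node reachable from `start`.
        component, queue = [], deque([start])
        seen[start] = True
        while queue:
            node = queue.popleft()
            component.append(node)
            for nbr in np.flatnonzero(adj[node]):
                if not seen[nbr]:
                    seen[nbr] = True
                    queue.append(nbr)
        if len(component) > len(best):
            best = component
    return sorted(int(i) for i in best)

# A sampled graph that fell apart into two pieces: 0 - 1 - 2 and 3 - 4.
adj = np.zeros((5, 5), dtype=np.int64)
for i, j in [(0, 1), (1, 2), (3, 4)]:
    adj[i, j] = adj[j, i] = 1

keep = largest_connected_component(adj)
sub_adj = adj[np.ix_(keep, keep)]  # adjacency matrix of the kept subgraph
```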
In this blog post, we reviewed the key ideas behind Categorical Normalizing Flows and GraphCNF. As Normalizing Flows are naturally defined for continuous data, we first map the discrete, categorical data into continuous space, on which a flow can be applied. While doing so, factorizing the decoder helps by pushing all modeling complexity into the flow. GraphCNF relies on this principle to model graphs with categorical attributes, such as molecules. Thanks to its three-step approach, it remains efficient while providing a permutation-invariant likelihood estimate of any graph. Future work can attempt to overcome the failure case of sampling disconnected subgraphs in parallel graph models like GraphCNF, and scale the approach up to large graphs with more than 100 nodes.
To read up on all the details of Categorical Normalizing Flows and GraphCNF, please see our paper Categorical Normalizing Flows via Continuous Transformations.