<?xml version="1.0" encoding="UTF-8"?>
<rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom">
  <channel>
    <title>Daniele Grattarola</title>
    <description>Artificial intelligence scientist</description>
    <link>https://danielegrattarola.github.io/</link>
    <atom:link href="nathanrooy.github.io/feed.xml" rel="self" type="application/rss+xml"/>
    <pubDate>Fri, 07 Nov 2025 09:29:19 +0000</pubDate>
    <lastBuildDate>Fri, 07 Nov 2025 09:29:19 +0000</lastBuildDate>
    <generator>Jekyll v3.10.0</generator>
      
    
    <item>
        <title>My second interview on Machine Learning Street Talk</title>
        <description>&lt;div class=&quot;video-container&quot;&gt;
    &lt;iframe src=&quot;https://www.youtube-nocookie.com/embed/v5NysEyZkl0&quot; frameborder=&quot;0&quot; allowfullscreen=&quot;&quot;&gt;&lt;/iframe&gt;
&lt;/div&gt;

&lt;p&gt;I was featured for the second time on &lt;a href=&quot;https://www.youtube.com/channel/UCMLtBahI5DMrt0NPvDSoIRQ&quot;&gt;Machine Learning Street Talk&lt;/a&gt;.&lt;/p&gt;

&lt;p&gt;This interview was shot at NeurIPS 2023 last year, where I was presenting our work on 
&lt;a href=&quot;https://arxiv.org/abs/2205.15674&quot;&gt;generalized implicit neural representations&lt;/a&gt; from my time at EPFL.&lt;/p&gt;

&lt;p&gt;Cheers!&lt;/p&gt;
</description>
        <pubDate>Sat, 16 Dec 2023 00:00:00 +0000</pubDate>
        
        <link>/posts/2023-12-16/MLST-2.html</link>
          
        
            <category>update</category>
        
          
        
            <category>posts</category>
        
          
      </item>
    
    <item>
        <title>My interview on Machine Learning Street Talk</title>
        <description>&lt;div class=&quot;video-container&quot;&gt;
    &lt;iframe src=&quot;https://www.youtube-nocookie.com/embed/MDt2e8XtUcA&quot; frameborder=&quot;0&quot; allowfullscreen=&quot;&quot;&gt;&lt;/iframe&gt;
&lt;/div&gt;

&lt;p&gt;I had the pleasure of being a guest on &lt;a href=&quot;https://www.youtube.com/channel/UCMLtBahI5DMrt0NPvDSoIRQ&quot;&gt;Machine Learning Street Talk&lt;/a&gt; to chat about cellular automata, emergence, life, the universe, and my own work on &lt;a href=&quot;https://danielegrattarola.github.io/posts/2021-11-08/graph-neural-cellular-automata.html&quot;&gt;graph neural cellular automata&lt;/a&gt;.&lt;/p&gt;

&lt;p&gt;I had a great time with Tim and Keith. They are doing incredible work with the podcast, and it’s really an honor to have been a part of it.&lt;/p&gt;

&lt;p&gt;Enjoy!&lt;/p&gt;

&lt;p&gt;&lt;sup style=&quot;font-size: 10px;&quot;&gt;P.S. I was so nervous and hyper-excited that I lost my own train of thought a couple of times, please be patient :D&lt;/sup&gt;&lt;/p&gt;
</description>
        <pubDate>Fri, 29 Apr 2022 00:00:00 +0000</pubDate>
        
        <link>/posts/2022-04-29/MLST.html</link>
          
        
            <category>update</category>
        
          
        
            <category>posts</category>
        
          
      </item>
    
    <item>
        <title>Graph Neural Cellular Automata</title>
        <description>&lt;p&gt;&lt;img src=&quot;https://danielegrattarola.github.io/images/2021-11-08/fixed_target_animation.gif&quot; alt=&quot;Graph Neural Cellular Automata for morphogenesis&quot; class=&quot;centered&quot; /&gt;&lt;/p&gt;

&lt;p&gt;&lt;a href=&quot;https://en.wikipedia.org/wiki/Cellular_automaton&quot;&gt;Cellular automata&lt;/a&gt; (or CA for short) are a fascinating computational model. 
They consist of a lattice of stateful cells and a transition rule that updates the state of each cell as a function of its neighbourhood configuration. 
By applying this local rule synchronously over time, we see interesting dynamics emerge.&lt;/p&gt;

&lt;p&gt;For example, here is the transition table of &lt;a href=&quot;https://en.wikipedia.org/wiki/Rule_110&quot;&gt;Rule 110&lt;/a&gt; in a 1-dimensional binary CA:&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;https://danielegrattarola.github.io/images/2021-11-08/Rule110-rule.png&quot; alt=&quot;Rule 110, transition table&quot; class=&quot;threeq-width&quot; /&gt;&lt;/p&gt;

&lt;p&gt;And here is the corresponding evolution of the states starting from a random initialization (time goes downwards):&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;https://danielegrattarola.github.io/images/2021-11-08/Rule110rand.png&quot; alt=&quot;Rule 110, evolution of the states&quot; class=&quot;quarter-width&quot; /&gt;&lt;/p&gt;
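
&lt;p&gt;To make the mechanics concrete, here is a minimal sketch of a 1-dimensional binary CA running Rule 110 in plain Python. The function names, the periodic boundary and the grid size are illustrative assumptions, not something taken from the images above.&lt;/p&gt;

```python
import random

# Rule 110 as a lookup table: (left, centre, right) -> next state.
# The outputs are the bits of 110 = 0b01101110, read from pattern 111 down to 000.
RULE_110 = {
    (1, 1, 1): 0, (1, 1, 0): 1, (1, 0, 1): 1, (1, 0, 0): 0,
    (0, 1, 1): 1, (0, 1, 0): 1, (0, 0, 1): 1, (0, 0, 0): 0,
}


def step(cells, rule=RULE_110):
    """Apply the rule synchronously to every cell (periodic boundary: an assumption)."""
    n = len(cells)
    return [rule[(cells[i - 1], cells[i], cells[(i + 1) % n])] for i in range(n)]


def evolve(cells, n_steps):
    """Return the list of states visited, starting from `cells`."""
    history = [list(cells)]
    for _ in range(n_steps):
        history.append(step(history[-1]))
    return history


if __name__ == "__main__":
    random.seed(0)
    initial = [random.randint(0, 1) for _ in range(64)]
    # Time goes downwards, one printed row per step
    for row in evolve(initial, 20):
        print("".join("#" if c else "." for c in row))
```

&lt;p&gt;Every cell only ever looks at itself and its two neighbours, and all cells are updated from the same old state.&lt;/p&gt;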

&lt;p&gt;By changing the rule, we get different dynamics, some of which can be extremely interesting. One example of this is the 2-dimensional &lt;a href=&quot;https://en.wikipedia.org/wiki/Conway%27s_Game_of_Life&quot;&gt;Game of Life&lt;/a&gt;, with its complex patterns that replicate and move around the grid.&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;https://danielegrattarola.github.io/images/2021-11-08/glider_gun.gif&quot; alt=&quot;Gosper glider gun&quot; class=&quot;half-width&quot; /&gt;&lt;/p&gt;

&lt;p&gt;We can also take this idea of locality to the extreme, by keeping it as the only requirement and making everything else more complicated.&lt;/p&gt;

&lt;p&gt;For example, if we make the states continuous and change the size of the neighbourhood, we get &lt;a href=&quot;https://arxiv.org/abs/1812.05433&quot;&gt;the mesmerizing Lenia CA&lt;/a&gt; with its &lt;em&gt;insanely&lt;/em&gt; life-like creatures that move around smoothly, reproduce, and even organize themselves into higher-order organisms.&lt;/p&gt;

&lt;div class=&quot;video-container&quot;&gt;
    &lt;iframe src=&quot;https://www.youtube-nocookie.com/embed/iE46jKYcI4Y&quot; frameborder=&quot;0&quot; allowfullscreen=&quot;&quot;&gt;&lt;/iframe&gt;
&lt;/div&gt;

&lt;p&gt;By this principle, we can also derive an even more general version of CA, in which the neighbourhoods of the cells no longer have a fixed shape and size. Instead, the cells of the CA are organized in an arbitrary graph.&lt;/p&gt;

&lt;p&gt;Note that the central idea of locality that characterizes CA does not change at all: we’re just extending it to account for these more general neighbourhoods.&lt;/p&gt;

&lt;p&gt;These super-general CA are usually called &lt;strong&gt;Graph Cellular Automata (GCA)&lt;/strong&gt;.&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;https://danielegrattarola.github.io/images/2021-11-08/gca_transition.png&quot; alt=&quot;Example of GCA transition&quot; class=&quot;half-width&quot; /&gt;&lt;/p&gt;

&lt;p&gt;The general form of GCA transition rules is a map from a cell and its neighbourhood to the next state, and we can also make it &lt;strong&gt;anisotropic&lt;/strong&gt; by introducing edge attributes that specify a relation between the cell and each neighbour.&lt;/p&gt;

&lt;h2 id=&quot;learning-ca-rules&quot;&gt;Learning CA rules&lt;/h2&gt;

&lt;p&gt;The world of CA is fascinating but, unfortunately, CA are almost always considered nothing more than pretty things.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;But can they also be useful? Can we design a rule to solve an interesting problem using the decentralized computation of CA?&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;The answer is yes, but manually designing such a rule may be hard. However, being AI scientists, we can try to learn the rule.&lt;/p&gt;

&lt;p&gt;This is not a new idea.&lt;/p&gt;

&lt;p&gt;We can go back to NeurIPS 1992 to find a seminal work on &lt;a href=&quot;https://papers.nips.cc/paper/1992/hash/d6c651ddcd97183b2e40bc464231c962-Abstract.html&quot;&gt;learning CA rules with neural networks&lt;/a&gt; (they use convolutional neural networks, although back then they were called “sum-product networks with shared weights”).&lt;/p&gt;

&lt;p&gt;Since then, we’ve seen other approaches to learning CA rules, like these papers using &lt;a href=&quot;https://mobile.aau.at/~welmenre/papers/elmenreich-iwsos2011.pdf&quot;&gt;genetic algorithms&lt;/a&gt; or &lt;a href=&quot;https://ieeexplore.ieee.org/abstract/document/8004527&quot;&gt;compositional pattern-producing networks&lt;/a&gt; to find rules that lead to a desired configuration of states, a task called &lt;strong&gt;morphogenesis&lt;/strong&gt;.&lt;sup style=&quot;font-size: 10px;&quot;&gt; &lt;a href=&quot;https://sci-hub.se/&quot;&gt;Papers not on arXiv, sorry&lt;/a&gt;&lt;/sup&gt;&lt;/p&gt;

&lt;p&gt;More recently, convolutional networks have been shown to be extremely versatile in learning CA rules. 
&lt;a href=&quot;https://arxiv.org/abs/1809.02942&quot;&gt;This work by William Gilpin&lt;/a&gt;, for example, shows that we can implement any desired transition rule with CNNs by smartly setting their weights.&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;https://danielegrattarola.github.io/images/2021-11-08/planarian.jpg&quot; alt=&quot;A planarian flatworm&quot; class=&quot;half-width&quot; /&gt;&lt;/p&gt;

&lt;p&gt;CNNs have also been used for morphogenesis. Inspired by the regenerative abilities of the flatworm (pictured above), the authors of &lt;a href=&quot;https://distill.pub/2020/growing-ca/&quot;&gt;this visually striking paper&lt;/a&gt; train a CNN to grow into a desired image and to regenerate the image if it is perturbed.&lt;/p&gt;

&lt;h2 id=&quot;learning-gca-rules&quot;&gt;Learning GCA rules&lt;/h2&gt;

&lt;p&gt;So, can we do something similar in the more general setting of GCA?&lt;/p&gt;

&lt;p&gt;Well, let’s start with the model. 
Similar to how CNNs are the natural family of models to implement typical grid-based CA rules, the more general family of graph neural networks is the natural choice for GCA.&lt;/p&gt;

&lt;p&gt;We call this setting the &lt;strong&gt;Graph Neural Cellular Automata (GNCA)&lt;/strong&gt;.&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;https://danielegrattarola.github.io/images/2021-11-08/thumbnail_cut.png&quot; alt=&quot;Graph Neural Cellular Automata&quot; class=&quot;threeq-width&quot; /&gt;&lt;/p&gt;

&lt;p&gt;We propose an architecture composed of a pre-processing MLP, a message-passing layer, and a post-processing MLP, which we use as the transition function.&lt;/p&gt;

&lt;p&gt;This model is universal, in the sense that it can represent any GCA transition rule. We can prove this by making an argument similar to the one for CNNs that I mentioned above.&lt;/p&gt;

&lt;p&gt;I won’t go into the specific details here, but in short, we need to implement two operations:&lt;/p&gt;

&lt;ol&gt;
  &lt;li&gt;One-hot encoding of the states;&lt;/li&gt;
  &lt;li&gt;Pattern-matching for the desired rule.&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;The first two blocks in our GNCA are more than enough to achieve this. 
The pre-processing MLP can compute the one-hot encoding, and by using edge attributes and &lt;a href=&quot;https://arxiv.org/abs/1704.02901&quot;&gt;edge-conditioned convolutions&lt;/a&gt; we can implement pattern matching easily.&lt;/p&gt;

&lt;h2 id=&quot;experiments&quot;&gt;Experiments&lt;/h2&gt;

&lt;p&gt;However, regardless of what the theory says, we want to know whether we can learn a rule in practice. Let’s try a few experiments.&lt;/p&gt;

&lt;h3 id=&quot;voronoi-gca&quot;&gt;Voronoi GCA&lt;/h3&gt;

&lt;p&gt;&lt;img src=&quot;https://danielegrattarola.github.io/images/2021-11-08/voronoi.png&quot; alt=&quot;Voronoi GCA&quot; class=&quot;half-width&quot; /&gt;&lt;/p&gt;

&lt;p&gt;We can start from the simplest possible binary GCA, inspired by the 1992 NeurIPS paper I mentioned before. The difference is that our CA cells are the &lt;a href=&quot;https://en.wikipedia.org/wiki/Voronoi_diagram&quot;&gt;Voronoi tessellation&lt;/a&gt; of some random points. 
Alternatively, you can think of this GCA as being defined on the &lt;a href=&quot;https://en.wikipedia.org/wiki/Delaunay_triangulation&quot;&gt;Delaunay triangulation&lt;/a&gt; of the points.&lt;/p&gt;

&lt;p&gt;We use an &lt;a href=&quot;https://en.wikipedia.org/wiki/Life-like_cellular_automaton&quot;&gt;outer-totalistic rule&lt;/a&gt; that swaps the state of a cell if the density of its alive neighbours exceeds a certain threshold, not too different from the Game of Life.&lt;/p&gt;

&lt;p&gt;We try to see if our model can learn this kind of transition rule. In particular, we can train the model to approximate the 1-step dynamics in a supervised way, given that we know the true transition rule.&lt;/p&gt;
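
&lt;p&gt;As a rough sketch of this kind of rule, the snippet below applies one synchronous step of a density-threshold rule on a small graph. The threshold of 0.5, the adjacency lists and the function name are hypothetical values chosen for illustration, not the settings used in the paper.&lt;/p&gt;

```python
def gca_step(states, neighbours, threshold=0.5):
    """One synchronous step: flip a cell if the fraction of alive neighbours exceeds the threshold."""
    new_states = []
    for i, s in enumerate(states):
        neigh = neighbours[i]
        # Density of alive neighbours (0 for an isolated cell)
        density = sum(states[j] for j in neigh) / len(neigh) if neigh else 0.0
        new_states.append(1 - s if density > threshold else s)
    return new_states


# A hypothetical 5-cell graph given as adjacency lists (not a real Delaunay triangulation)
neighbours = {0: [1, 2], 1: [0, 2, 3], 2: [0, 1, 4], 3: [1, 4], 4: [2, 3]}
states = [1, 1, 0, 0, 0]

# Only cell 2 has more than half of its neighbours alive, so only cell 2 flips
next_states = gca_step(states, neighbours)
assert next_states == [1, 1, 1, 0, 0]
```

&lt;p&gt;Pairs of the form (state, next state) produced this way are exactly the kind of supervision that the 1-step training setting relies on.&lt;/p&gt;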

&lt;div style=&quot;text-align: center&quot;&gt;
&lt;img src=&quot;https://danielegrattarola.github.io/images/2021-11-08/learn_gca_loss_v_epoch.svg&quot; width=&quot;30%&quot; style=&quot;display: inline-block; margin:auto;&quot; /&gt;&amp;nbsp;
&lt;img src=&quot;https://danielegrattarola.github.io/images/2021-11-08/learn_gca_acc_v_epoch.svg&quot; width=&quot;30%&quot; style=&quot;display: inline-block; margin:auto;&quot; /&gt;
&lt;/div&gt;

&lt;p&gt;The results are encouraging. We see that the GNCA achieves 100% accuracy with no trouble and, if we let it evolve autonomously, it does not diverge from the real trajectory.&lt;/p&gt;

&lt;h3 id=&quot;boids&quot;&gt;Boids&lt;/h3&gt;
&lt;div style=&quot;text-align: center&quot;&gt;
&lt;img src=&quot;https://danielegrattarola.github.io/images/2021-11-08/alignment.png&quot; width=&quot;30%&quot; style=&quot;display: inline-block; margin:auto;&quot; /&gt;&amp;nbsp;
&lt;img src=&quot;https://danielegrattarola.github.io/images/2021-11-08/cohesion.png&quot; width=&quot;30%&quot; style=&quot;display: inline-block; margin:auto;&quot; /&gt;&amp;nbsp;
&lt;img src=&quot;https://danielegrattarola.github.io/images/2021-11-08/separation.png&quot; width=&quot;30%&quot; style=&quot;display: inline-block; margin:auto;&quot; /&gt;
&lt;/div&gt;

&lt;p&gt;For our second experiment, we keep a similar setting but make the target GCA much more complicated. 
We consider the &lt;a href=&quot;https://en.wikipedia.org/wiki/Boids&quot;&gt;Boids&lt;/a&gt; algorithm, an agent-based model designed to simulate the flocking of birds. This can still be seen as a kind of GCA because the state of each bird (its position and velocity) is updated only locally as a function of its closest neighbours.
However, this means that the states of the GCA are continuous and multi-dimensional, and also that the graph changes over time.&lt;/p&gt;

&lt;p&gt;Again, we can train the GNCA on the 1-step dynamics. We see that, although it’s hard to approximate the exact behaviour, we get very close to the true system. 
The GNCA (yellow) can form the same kind of flocks as the true system (purple), even if their trajectories diverge.&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;https://danielegrattarola.github.io/images/2021-11-08/boids_animation.gif&quot; alt=&quot;Boids GCA and trained GNCA&quot; class=&quot;centered&quot; /&gt;&lt;/p&gt;

&lt;h3 id=&quot;morphogenesis&quot;&gt;Morphogenesis&lt;/h3&gt;

&lt;p&gt;The final experiment is also the most interesting, and the one where we actually design a rule. 
As in the previous literature, we focus on morphogenesis. Our task is to find a GNCA rule that, starting from a given initial condition, converges to a desired point cloud (like a bunny) where the connectivity of the cells has a geometrical/spatial meaning.&lt;/p&gt;

&lt;p&gt;In this case, we don’t know the true rule, so we must train the model differently, by teaching it to arrive at the target state when evolving autonomously.&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;https://danielegrattarola.github.io/images/2021-11-08/gnca_training.png&quot; alt=&quot;Training scheme for GNCA&quot; class=&quot;threeq-width&quot; /&gt;&lt;/p&gt;

&lt;p&gt;To do so, we let the model evolve for a given number of steps, then we compute the loss from the target, and we update the weights with backpropagation through time. 
To stabilise training, and to ensure that the target state becomes a stable attractor of the GNCA, we use a cache. This is a kind of replay memory from which we sample the initial conditions, so that we can reuse the states explored by the GNCA during training.
Crucially, this teaches the model to remain at the target state when starting from the target state.&lt;/p&gt;
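
&lt;p&gt;The bookkeeping around the cache can be sketched as follows. This is only an illustration of the sampling logic: the class, the toy transition function and all sizes are hypothetical, and the loss and the backpropagation-through-time step are left out as a comment.&lt;/p&gt;

```python
import random


class StateCache:
    """Replay memory of states; every slot starts from the initial condition."""

    def __init__(self, initial_state, size, seed=0):
        self.states = [list(initial_state) for _ in range(size)]
        self.rng = random.Random(seed)

    def sample(self, batch_size):
        """Pick distinct slots and return their indices and a copy of their states."""
        idx = self.rng.sample(range(len(self.states)), batch_size)
        return idx, [list(self.states[i]) for i in idx]

    def update(self, idx, new_states):
        """Store the states reached by the model, so later rollouts can start from them."""
        for i, s in zip(idx, new_states):
            self.states[i] = list(s)


def rollout(step_fn, state, n_steps):
    """Let the transition function evolve autonomously for n_steps."""
    for _ in range(n_steps):
        state = step_fn(state)
    return state


def toy_step(state):
    """Stand-in for the GNCA: moves every value halfway towards a target of 1.0."""
    return [0.5 * (v + 1.0) for v in state]


cache = StateCache(initial_state=[0.0, 0.0], size=4)
for _ in range(10):
    idx, batch = cache.sample(2)
    evolved = [rollout(toy_step, s, 3) for s in batch]
    # Here one would compute the loss from the target and update the weights
    cache.update(idx, evolved)
```

&lt;p&gt;Because evolved states are written back, later rollouts start closer and closer to the target, which is what pushes the model to stay there.&lt;/p&gt;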

&lt;p&gt;And the results are pretty amazing… have you seen the gif at the &lt;a href=&quot;#&quot;&gt;top of the post&lt;/a&gt;? Let’s unroll the first few frames here.&lt;/p&gt;

&lt;p&gt;A 2-dimensional grid:&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;https://danielegrattarola.github.io/images/2021-11-08/clouds/Grid_10-20/evolution.png&quot; alt=&quot;&quot; /&gt;&lt;/p&gt;

&lt;p&gt;A bunny:&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;https://danielegrattarola.github.io/images/2021-11-08/clouds/Bunny_10-20/evolution.png&quot; alt=&quot;&quot; /&gt;&lt;/p&gt;

&lt;p&gt;The &lt;a href=&quot;https://pygsp.readthedocs.io/en/stable/&quot;&gt;PyGSP&lt;/a&gt; logo:&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;https://danielegrattarola.github.io/images/2021-11-08/clouds/Logo_20/evolution.png&quot; alt=&quot;&quot; /&gt;&lt;/p&gt;

&lt;p&gt;We see that the GNCA has no trouble finding a stable rule that converges quickly to the target and then remains there.&lt;/p&gt;

&lt;p&gt;Even for complex and seemingly random graphs, like the Minnesota road network, the GNCA can learn a rule that quickly and stably converges to the target:&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;https://danielegrattarola.github.io/images/2021-11-08/clouds/Minnesota_20/anim.gif&quot; alt=&quot;&quot; class=&quot;third-width&quot; /&gt;&lt;/p&gt;

&lt;p&gt;However, this is not the full story. Sometimes, instead of converging, the GNCA learns to remain in an orbit around the target state, giving us these oscillating point clouds.&lt;/p&gt;

&lt;p&gt;Grid:&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;https://danielegrattarola.github.io/images/2021-11-08/clouds/Grid_10/evolution.png&quot; alt=&quot;&quot; /&gt;&lt;/p&gt;

&lt;p&gt;Bunny:&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;https://danielegrattarola.github.io/images/2021-11-08/clouds/Bunny_10/evolution.png&quot; alt=&quot;&quot; /&gt;&lt;/p&gt;

&lt;p&gt;Logo:&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;https://danielegrattarola.github.io/images/2021-11-08/clouds/Logo_10/evolution.png&quot; alt=&quot;&quot; /&gt;&lt;/p&gt;

&lt;h2 id=&quot;now-what&quot;&gt;Now what?&lt;/h2&gt;

&lt;p&gt;So, where do we go from here?&lt;/p&gt;

&lt;p&gt;We have seen that GNCA can reach global coherence through local computation, which is not that different from what we do in graph representation learning. In fact, &lt;a href=&quot;https://www.researchgate.net/profile/Franco_Scarselli/publication/4202380_A_new_model_for_earning_in_raph_domains/links/0c9605188cd580504f000000.pdf&quot;&gt;the first GNN paper&lt;/a&gt;, back in 2005, already contained this idea.&lt;/p&gt;

&lt;p&gt;But moving forward, it’s easy to see that the idea of emergent computation on graphs could apply to many scenarios, including swarm optimization and control, modelling epidemiological transmission, and it could even improve our understanding of complex biological systems, like the brain.&lt;/p&gt;

&lt;p&gt;GNCA enable the design of GCA transition rules, unlocking the power of decentralised and emergent computation to solve real-world problems.&lt;/p&gt;

&lt;p&gt;The code for the paper is available &lt;a href=&quot;https://github.com/danielegrattarola/GNCA&quot;&gt;on GitHub&lt;/a&gt;. Feel free to reach out via email if you have any questions or comments.&lt;/p&gt;

&lt;h2 id=&quot;read-more&quot;&gt;Read more&lt;/h2&gt;

&lt;p&gt;This blog post is the short version of our NeurIPS 2021 paper:&lt;/p&gt;

&lt;p&gt;&lt;a href=&quot;https://arxiv.org/abs/2110.14237&quot;&gt;Learning Graph Cellular Automata&lt;/a&gt;&lt;br /&gt;
&lt;em&gt;D. Grattarola, L. Livi, C. Alippi&lt;/em&gt;&lt;/p&gt;

&lt;p&gt;You can cite the paper as follows:&lt;/p&gt;

&lt;div class=&quot;language-plaintext highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;@inproceedings{grattarola2021learning,
  title={Learning Graph Cellular Automata},
  author={Grattarola, Daniele and Livi, Lorenzo and Alippi, Cesare},
  booktitle={Neural Information Processing Systems},
  year={2021}
}
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
</description>
        <pubDate>Mon, 08 Nov 2021 00:00:00 +0000</pubDate>
        
        <link>/posts/2021-11-08/graph-neural-cellular-automata.html</link>
          
        
            <category>GNN</category>
        
            <category>cellular-automata</category>
        
          
        
            <category>posts</category>
        
          
      </item>
    
    <item>
        <title>A practical introduction to GNNs - Part 2</title>
        <description>&lt;p&gt;&lt;em&gt;This is Part 2 of an introductory lecture on graph neural networks that I gave for the “Graph Deep Learning” course at the University of Lugano.&lt;/em&gt;&lt;/p&gt;

&lt;p&gt;&lt;em&gt;After a practical introduction to GNNs in &lt;a href=&quot;https://danielegrattarola.github.io/posts/2021-03-03/gnn-lecture-part-1.html&quot;&gt;Part 1&lt;/a&gt;, here I show how we can formulate GNNs in a much more flexible way using the idea of message passing.&lt;/em&gt;&lt;/p&gt;

&lt;p&gt;&lt;em&gt;First, I introduce message passing. Then, I show how to implement message-passing networks in Jax/pseudocode using a paradigm called “gather-scatter”. Finally, I show how to implement a couple of more advanced GNN models.&lt;/em&gt;&lt;/p&gt;

&lt;p&gt;&lt;a href=&quot;https://danielegrattarola.github.io/files/talks/2021-03-01-USI_GDL_GNNs.pdf&quot;&gt;The full slide deck is available here&lt;/a&gt;.&lt;/p&gt;

&lt;hr /&gt;

&lt;p&gt;In &lt;a href=&quot;https://danielegrattarola.github.io/posts/2021-03-03/gnn-lecture-part-1.html&quot;&gt;Part 1&lt;/a&gt; of this series we constructed our first kind of GNN by replicating the behavior of conventional CNNs on data supported by graphs.&lt;/p&gt;

&lt;p&gt;The core building block that we used in our simple GNNs looked like this:&lt;/p&gt;

\[\mathbf{X}&apos; = \mathbf{R}\mathbf{X}\mathbf{\Theta}\]

&lt;p&gt;which, as we saw, has two effects:&lt;/p&gt;

&lt;ol&gt;
  &lt;li&gt;All node attributes \(\mathbf{X}\) are transformed using the learnable matrix \(\mathbf{\Theta}\);&lt;/li&gt;
  &lt;li&gt;The attributes of each node get replaced with a weighted sum of its neighbors’ attributes via the reference operator \(\mathbf{R}\) (sometimes we also include the node itself in the sum).&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;By combining these two ideas we were able to get a very good approximation of a CNN for graphs.&lt;/p&gt;

&lt;p&gt;In this part of the lecture, we will take these two ideas and describe them a little more formally, distilling the essential role that they have in a GNN.&lt;/p&gt;

&lt;p&gt;We will see a general framework called &lt;strong&gt;message passing&lt;/strong&gt;, which will allow us to describe more complex GNNs than those we have seen so far.&lt;/p&gt;

&lt;h2 id=&quot;message-passing-networks&quot;&gt;Message Passing Networks&lt;/h2&gt;

&lt;p&gt;&lt;img src=&quot;https://danielegrattarola.github.io/images/2021-03-03/presentation-14.svg&quot; width=&quot;100%&quot; style=&quot;border: solid 1px;&quot; /&gt;&lt;/p&gt;

&lt;p&gt;The idea of message passing networks was introduced in a paper by &lt;a href=&quot;&quot;&gt;Gilmer et al.&lt;/a&gt; in 2017 and it essentially boils GNN layers down to three main steps:&lt;/p&gt;

&lt;ol&gt;
  &lt;li&gt;Every node in the graph computes a &lt;strong&gt;message&lt;/strong&gt; for each of its neighbors. Messages are a function of the node, the neighbor, and the edge between them.&lt;/li&gt;
  &lt;li&gt;Messages are sent, and every node &lt;strong&gt;aggregates&lt;/strong&gt; the messages it receives, using a permutation-invariant function (i.e., it doesn’t matter in which order the messages are received). This function is usually a sum or an average.&lt;/li&gt;
  &lt;li&gt;After receiving the messages, each node &lt;strong&gt;updates&lt;/strong&gt; its attributes as a function of its current attributes and the aggregated messages.&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;This procedure happens synchronously for all nodes in the graph, so that at each message passing step all nodes are updated.&lt;/p&gt;

&lt;p&gt;If we look back at our super-simple GNN formulation \(\mathbf{X}&apos; = \mathbf{R}\mathbf{X}\mathbf{\Theta}\), we can easily see the three message-passing steps:&lt;/p&gt;

&lt;ol&gt;
  &lt;li&gt;&lt;strong&gt;Message&lt;/strong&gt; - Each node \(i\) will receive the same kind of message \(\mathbf{\Theta}^\top\mathbf{x}_j\) from all its neighbors \(j \in \mathcal{N}(i)\).&lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;Aggregate&lt;/strong&gt; - Messages are aggregated with a weighted sum, where weights are defined by the reference operator \(\mathbf{R}\).&lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;Update&lt;/strong&gt; - Each node simply replaces its attributes with the aggregated messages. &lt;br /&gt;
If \(\mathbf{R}\) has a non-zero diagonal, then each node also computes a message “from itself to itself” using \(\mathbf{\Theta}\).&lt;/li&gt;
&lt;/ol&gt;
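
&lt;p&gt;The three steps above can be checked numerically on a toy example. The sketch below uses plain Python lists and compares the node-by-node computation with the matrix product; the graph, the attribute values and the helper names are made up for illustration.&lt;/p&gt;

```python
def matmul(a, b):
    """Plain-Python matrix product for lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def message_passing(R, X, Theta):
    """Compute X' = R X Theta node by node, following the three steps above."""
    n_nodes, n_out = len(X), len(Theta[0])
    X_new = []
    for i in range(n_nodes):
        aggregated = [0.0] * n_out
        for j in range(n_nodes):
            if R[i][j] == 0:
                continue  # j sends nothing to i
            # Message: Theta^T x_j
            message = [sum(Theta[k][f] * X[j][k] for k in range(len(X[j]))) for f in range(n_out)]
            # Aggregate: weighted sum, with weights given by R
            aggregated = [a + R[i][j] * m for a, m in zip(aggregated, message)]
        # Update: replace the attributes with the aggregated messages
        X_new.append(aggregated)
    return X_new


# Toy path graph 0 - 1 - 2 with R equal to the adjacency matrix (illustrative values)
R = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
X = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
Theta = [[1.0, 0.0], [1.0, -1.0]]

# Both formulations give the same new attributes
assert message_passing(R, X, Theta) == matmul(matmul(R, X), Theta)
```

&lt;p&gt;Note that the new attributes are collected in a separate list, so every node reads the old attributes of its neighbors.&lt;/p&gt;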

&lt;p&gt;&lt;img src=&quot;https://danielegrattarola.github.io/images/2021-03-03/presentation-15.svg&quot; width=&quot;100%&quot; style=&quot;border: solid 1px;&quot; /&gt;&lt;/p&gt;

&lt;p&gt;Message passing is usually formalized with the equation in the slide above.&lt;/p&gt;

&lt;p&gt;While it may look complicated at first, the formula simply describes the three steps that we just saw, and if we wanted to write it in Python it would look something like this:&lt;/p&gt;

&lt;div class=&quot;language-py highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# Write updates to a copy, so that all nodes read the same old attributes
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x_new&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;list&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# For every node in the graph
&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n_nodes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# Compute messages from neighbors
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;messages&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;message&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;j&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;e&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;j&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;j&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;neighbors&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)]&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# Aggregate messages
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;aggregated&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;aggregate&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;messages&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# Update node attributes
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;x_new&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;update&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;aggregated&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# All nodes switch to their new attributes at once
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x_new&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;As long as &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;message&lt;/code&gt;, &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;aggregate&lt;/code&gt;, and &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;update&lt;/code&gt; are differentiable functions, we can train a neural network to transform its inputs like this. &lt;br /&gt;
In fact, this framework is so general that virtually all libraries that implement GNNs are based on it.&lt;/p&gt;

&lt;p&gt;For example, &lt;a href=&quot;https://graphneural.network&quot;&gt;Spektral&lt;/a&gt;, &lt;a href=&quot;https://pytorch-geometric.readthedocs.io/&quot;&gt;PyTorch Geometric&lt;/a&gt;, and &lt;a href=&quot;https://www.dgl.ai/&quot;&gt;DGL&lt;/a&gt; all have a &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;MessagePassing&lt;/code&gt; class that looks like this:&lt;/p&gt;

&lt;div class=&quot;language-py highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;MessagePassing&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Layer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;call&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;inputs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;kwargs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# This is the actual message-passing step
&lt;/span&gt;        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;propagate&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;inputs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;propagate&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;a&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;e&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;kwargs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# Compute messages
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;messages&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;message&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;msg_kwargs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

        &lt;span class=&quot;c1&quot;&gt;# Aggregate messages
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;aggregated&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;aggregate&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;messages&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;agg_kwargs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

        &lt;span class=&quot;c1&quot;&gt;# Update self
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;output&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;update&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;aggregated&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;upd_kwargs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;output&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;message&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;kwargs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;...&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;aggregate&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;messages&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;kwargs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;...&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;update&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;aggregated&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;kwargs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;...&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;h2 id=&quot;gather-scatter&quot;&gt;Gather-Scatter&lt;/h2&gt;

&lt;p&gt;The cool thing about message passing is that it lets us define the operations that our GNN computes, without necessarily resorting to matrix multiplication.&lt;/p&gt;

&lt;p&gt;In fact, the only thing that we specify is how the GNN acts on a generic node \(i\) as a function of its generic neighbors \(j \in \mathcal{N}(i)\).&lt;/p&gt;

&lt;p&gt;For instance, let’s say that we wanted to implement the “Edge Convolution” operator from the paper &lt;a href=&quot;https://arxiv.org/abs/1801.07829&quot;&gt;“Dynamic Graph CNN for Learning on Point Clouds”&lt;/a&gt;.&lt;/p&gt;

&lt;p&gt;In the message-passing framework, we write its effect as:&lt;/p&gt;

\[\mathbf{x}_i&apos; = \sum\limits_{j \in \mathcal{N}(i)} \textrm{MLP}\big( \mathbf{x}_i \| \mathbf{x}_j - \mathbf{x}_i \big)\]

&lt;p&gt;If we wanted to implement this as a matrix multiplication like we have done so far, we would run into trouble, because GNNs of the form \(\mathbf{R}\mathbf{X}\mathbf{\Theta}\) assume that every node sends the same message to each of its neighbors. Here, instead, each message is a function of the edge \(j \rightarrow i\) it travels along.&lt;/p&gt;

&lt;p&gt;In fact, this limitation applies to every GNN with edge-dependent messages.&lt;/p&gt;

&lt;p&gt;We could still implement our Edge Convolution using broadcasting operations, but it would not be efficient at all. Here’s one way we could do it:&lt;/p&gt;

&lt;div class=&quot;language-py highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;numpy&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;...&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# Node attributes of shape [n, f]
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;a&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;...&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# Adjacency matrix of shape [n, n]
&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;# Compute all pairwise differences between nodes
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x_diff&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[:,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:]&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# shape: (n, n, f)
&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;# Repeat the nodes so that we can concatenate them to the differences
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x_repeat&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;repeat&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[:,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;n&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;axis&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# shape: (n, n, f)
&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;# Concatenate the attributes so that, for each edge, we have x_i || (x_j - x_i)
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x_all&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;concatenate&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x_repeat&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x_diff&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;axis&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# shape: (n, n, 2 * f)
&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;# Give x_i || (x_j - x_i) as input to an MLP
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;messages&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mlp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x_all&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# shape: (n, n, channels)
&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;# Broadcast-multiply `a` to keep only &quot;real&quot; messages
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;output&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;a&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[...,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;messages&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# shape: (n, n, channels)
&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;# Sum along the &quot;neighbors&quot; axis.
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;output&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;output&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;sum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# shape: (n, channels)
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Note that we had to compute messages for &lt;strong&gt;all possible edges&lt;/strong&gt; and then simply multiply some of the messages by zero using &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;a&lt;/code&gt;.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;This is not ideal&lt;/strong&gt;, because it costs us \(O(N^2)\) to do something that should have a cost linear in the number of edges (this is a big difference when working with real-world graphs, which are usually very sparse).&lt;/p&gt;

&lt;p&gt;A much better way to achieve our goal is to exploit the advanced indexing features offered by all libraries for tensor manipulation, using a technique called &lt;strong&gt;gather-scatter&lt;/strong&gt;.&lt;/p&gt;

&lt;p&gt;The gather-scatter technique requires us to think a bit differently, using node indices to access &lt;strong&gt;only the nodes that we are interested in&lt;/strong&gt;, in a sparse way.&lt;/p&gt;

&lt;p&gt;This is much easier done than said, so let’s see an example.&lt;/p&gt;

&lt;p&gt;Let us consider an adjacency matrix &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;a&lt;/code&gt;:&lt;/p&gt;

&lt;div class=&quot;language-py highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;a&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt;
     &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt;
     &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]]&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;This matrix is equivalently represented in the sparse COOrdinate format:&lt;/p&gt;

&lt;div class=&quot;language-py highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;row&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# Nodes that are sending a message
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;col&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# Nodes that are receiving a message
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;which simply tells us the indices of the non-zero entries of &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;a&lt;/code&gt; (we usually also have an extra array that tells us the actual values of the entries, but we won’t need it for now).&lt;/p&gt;

&lt;p&gt;If we consider all edges \(j \rightarrow i\), then the attributes of all nodes that are &lt;em&gt;sending&lt;/em&gt; a message can be retrieved with &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;x[row]&lt;/code&gt;. 
Similarly, the attributes of nodes that are receiving a message can be retrieved with &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;x[col]&lt;/code&gt;.&lt;/p&gt;

&lt;p&gt;This is called &lt;strong&gt;gathering&lt;/strong&gt; the nodes.&lt;/p&gt;

&lt;p&gt;In our case, if we want to take the difference of the nodes at the opposite side of an edge, we can simply do &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;x[row] - x[col]&lt;/code&gt;. 
Instead of computing the difference &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;x[j] - x[i]&lt;/code&gt; for all possible pairs &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;j, i&lt;/code&gt;, like we did before, now we only compute the differences that we are really interested in.&lt;/p&gt;

&lt;p&gt;All these operations will give us matrices that have as many rows as there are edges. So for instance, &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;x[row]&lt;/code&gt; will look like this:&lt;/p&gt;

&lt;div class=&quot;language-py highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt;
 &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt;
 &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt;
 &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt;
 &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]]&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# shape: (n_edges, f)
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
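
&lt;p&gt;To make gathering concrete, here is a minimal runnable sketch on the three-node graph above. The values in &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;x&lt;/code&gt; are made up for illustration, and the names &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;x_senders&lt;/code&gt; and &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;x_receivers&lt;/code&gt; are just assumptions of this example:&lt;/p&gt;

```python
import jax.numpy as jnp

# Toy node attributes (assumed values): 3 nodes, 2 features each
x = jnp.array([[1.0, 10.0],
               [2.0, 20.0],
               [3.0, 30.0]])

# COO indices of the adjacency matrix from the example above
row = jnp.array([0, 0, 1, 2, 2])  # Nodes that are sending a message
col = jnp.array([0, 2, 2, 0, 1])  # Nodes that are receiving a message

# Gather: one row of attributes per edge
x_senders = x[row]      # shape: (n_edges, f)
x_receivers = x[col]    # shape: (n_edges, f)

# One difference per edge, instead of one per pair of nodes
x_diff = x_senders - x_receivers  # shape: (n_edges, f)

# The first feature of each difference is 0, -2, -1, 2, 1
print(x_diff[:, 0])
```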

&lt;p&gt;The other half of this story tells us how to aggregate the messages after we have gathered them. We call this &lt;strong&gt;scattering&lt;/strong&gt;.&lt;/p&gt;

&lt;p&gt;For all nodes \(i\), we want to aggregate all messages that are being sent via edges that have index \(i\) on the &lt;strong&gt;receiving&lt;/strong&gt; end, i.e., all edges of the form \(j \rightarrow i\).
For instance, in the small example above we know that node 2 will receive a message from nodes 0 and 1.&lt;/p&gt;

&lt;p&gt;We can do this using some special operations available in most libraries for tensor manipulation:&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;In TensorFlow, we have &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;tf.math.segment_[sum|prod|mean|max|min]&lt;/code&gt;.&lt;/li&gt;
  &lt;li&gt;For PyTorch, we have the &lt;a href=&quot;https://github.com/rusty1s/pytorch_scatter&quot;&gt;Torch Scatter&lt;/a&gt; library by Matthias Fey.&lt;/li&gt;
  &lt;li&gt;In JAX, we have &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;jax.ops.segment_[sum|prod|max|min]&lt;/code&gt;.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;These operations apply a reduction to “segments” of a tensor, where the segments are defined by integer indices. Something like this:&lt;/p&gt;

&lt;div class=&quot;language-py highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# Example: segment sum
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;data&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;5&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;7&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;4&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;      &lt;span class=&quot;c1&quot;&gt;# A tensor that we want to reduce
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;segments&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# Segment indices (we have 4 segments)
&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;output&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;max&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;segments&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;   &lt;span class=&quot;c1&quot;&gt;# One result for each segment
&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;s&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;enumerate&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;segments&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;output&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;s&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;             &lt;span class=&quot;c1&quot;&gt;# It could also be a product, max, etc...
&lt;/span&gt;
&lt;span class=&quot;o&quot;&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;output&lt;/span&gt; 
&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;13&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;7&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;4&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
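
&lt;p&gt;The same reduction can be written with &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;jax.ops.segment_sum&lt;/code&gt;. This is only a small sketch on the toy data from the loop above; passing &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;num_segments&lt;/code&gt; is optional here, but it fixes the size of the output and has to be a static value if the call ends up inside a JIT-compiled function:&lt;/p&gt;

```python
import jax
import jax.numpy as jnp

data = jnp.array([5, 1, 7, 2, 3, 4, 1, 3])      # A tensor that we want to reduce
segments = jnp.array([0, 0, 0, 1, 2, 2, 3, 3])  # Segment indices (we have 4 segments)

# One result for each segment: the sums are 13, 2, 7 and 4
output = jax.ops.segment_sum(data, segments, num_segments=4)
print(output)
```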

&lt;p&gt;So for instance, if we want to sum all messages based on their intended recipient, we can do:&lt;/p&gt;

&lt;div class=&quot;language-py highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;aggregated&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ops&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;segment_sum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;messages&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;col&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; 
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Now we can put all of this together to create our Edge Convolution layer with a gather-scatter implementation:&lt;/p&gt;

&lt;div class=&quot;language-py highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;scipy.sparse&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;...&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# Node attributes of shape [n, f]
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;a&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;...&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# Adjacency matrix of shape [n, n]
&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;# Get indices of the non-zero entries of the adjacency matrix
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;senders&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;receivers&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;_&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;scipy&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sparse&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;find&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;a&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# Calculate difference of nodes for each edge j -&amp;gt; i
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x_diff&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;senders&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;receivers&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# shape: (n_edges, f)
&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;# Concatenate x_i with (x_j - x_i) for each edge j -&amp;gt; i
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x_all&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;concatenate&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;receivers&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x_diff&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;axis&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# shape: (n_edges, 2 * f)
&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;# Give x_i || (x_j - x_i) as input to an MLP
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;messages&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mlp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x_all&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# shape: (n_edges, channels)
&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;# Aggregate all messages according to their intended receiver
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;output&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ops&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;segment_sum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;messages&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;receivers&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# shape: (n, channels)
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Wrap this up in a layer and we’re done!&lt;/p&gt;

&lt;p&gt;Here’s what it looks like &lt;a href=&quot;https://github.com/danielegrattarola/spektral/blob/master/spektral/layers/convolutional/edge_conv.py&quot;&gt;in Spektral&lt;/a&gt; and &lt;a href=&quot;https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/conv/edge_conv.html#EdgeConv&quot;&gt;in PyTorch Geometric&lt;/a&gt;.&lt;/p&gt;
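
&lt;p&gt;As a sanity check, the gather-scatter computation can be compared against the dense broadcasting version on the small graph from before. This is a rough sketch under a few assumptions: the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;mlp&lt;/code&gt; is a hypothetical stand-in (a single random dense layer with a ReLU), the node attributes are random, and &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;jnp.nonzero&lt;/code&gt; replaces &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;scipy.sparse.find&lt;/code&gt; so that the example only depends on JAX:&lt;/p&gt;

```python
import jax
import jax.numpy as jnp

n, f, channels = 3, 2, 4
key_x, key_w = jax.random.split(jax.random.PRNGKey(0))

x = jax.random.normal(key_x, (n, f))  # Random node attributes of shape (n, f)
a = jnp.array([[1, 0, 1],
               [0, 0, 1],
               [1, 1, 0]])

# Hypothetical stand-in for the MLP: one dense layer followed by a ReLU
w = jax.random.normal(key_w, (2 * f, channels))

def mlp(inputs):
    return jax.nn.relu(inputs @ w)

# Gather: indices of the non-zero entries (senders = rows, receivers = columns)
senders, receivers = jnp.nonzero(a)
x_diff = x[senders] - x[receivers]                        # x_j - x_i
x_all = jnp.concatenate([x[receivers], x_diff], axis=-1)  # x_i || (x_j - x_i)
messages = mlp(x_all)                                     # shape: (n_edges, channels)

# Scatter: sum the messages that share the same receiver
sparse_out = jax.ops.segment_sum(messages, receivers, num_segments=n)

# Dense reference: compute all n * n messages, then mask with the adjacency.
# a.T is used so that entry [i, j] marks an edge from j to i, as in the sparse version.
pair_diff = x[None, :, :] - x[:, None, :]                 # [i, j] = x_j - x_i
pair_self = jnp.broadcast_to(x[:, None, :], (n, n, f))    # [i, j] = x_i
pair_msgs = mlp(jnp.concatenate([pair_self, pair_diff], axis=-1))
dense_out = (a.T[..., None] * pair_msgs).sum(1)           # shape: (n, channels)

# Both versions should agree up to floating-point error
print(jnp.allclose(sparse_out, dense_out, atol=1e-5))
```

&lt;p&gt;The sparse version only evaluates the MLP on the five existing edges, while the dense one evaluates it on all nine pairs of nodes and throws four of the results away.&lt;/p&gt;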

&lt;h2 id=&quot;methods&quot;&gt;Methods&lt;/h2&gt;

&lt;p&gt;We have now moved past the simple GNNs based on a multiplication by the reference operator and with edge-independent messages that we saw in the first part of this series. Let’s look at some more advanced methods!&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;https://danielegrattarola.github.io/images/2021-03-03/presentation-17.svg&quot; width=&quot;100%&quot; style=&quot;border: solid 1px;&quot; /&gt;&lt;/p&gt;

&lt;p&gt;For instance, the popular &lt;a href=&quot;https://arxiv.org/abs/1710.10903&quot;&gt;Graph Attention Networks&lt;/a&gt; by Veličković et al. can be implemented as a message-passing network using gather-scatter:&lt;/p&gt;

&lt;div class=&quot;language-py highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# Transform node attributes with a dense layer
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;h&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dense&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# Concatenate attributes of receivers/senders
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;h_cat&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;concatenate&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;h&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;receivers&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;h&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;senders&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;axis&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# Compute attention logits with a dense layer (output dim = 1, LeakyReLU)
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;logits&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dense&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;h_cat&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# Apply softmax only to the logits in the same segment, as defined by receivers
# i.e., normalize the scores only among the neighbors of each node.
# Note that segment_softmax does **not** reduce the tensor: `coef` has the same 
# shape as `logits`.
# This function is available in Spektral and PyG.
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;coef&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;segment_softmax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;logits&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;receivers&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# Now we aggregate with a weighted sum (weights given by coef)
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;output&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ops&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;segment_sum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;coef&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;h&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;senders&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;receivers&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;&lt;img src=&quot;https://danielegrattarola.github.io/images/2021-03-03/presentation-18.svg&quot; width=&quot;100%&quot; style=&quot;border: solid 1px;&quot; /&gt;&lt;/p&gt;

&lt;p&gt;Easily enough, we can also define a message-passing network that includes edge attributes in the computation of messages. One of my favorite models is the &lt;a href=&quot;https://arxiv.org/abs/1704.02901&quot;&gt;Edge-Conditioned Convolution&lt;/a&gt; by Simonovsky &amp;amp; Komodakis, of which I’ve summarized the math in the slide above.&lt;/p&gt;

&lt;p&gt;To implement it with gather-scatter we can do:&lt;/p&gt;

&lt;div class=&quot;language-py highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# Use a Filter-Generating Network to create a feature of size (f * f_,) for each 
# edge
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;kernel&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;filter_generating_network&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;e&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# Reshape the weights so that we have a matrix of shape (f, f_) for each edge
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;kernel&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;reshape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;kernel&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;f_&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# Multiply the node attribute of each neighbor by the associated edge-dependent
# kernel. We can use einsum to do this efficiently.
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;messages&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;einsum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;ab,abc-&amp;gt;ac&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;senders&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kernel&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# Aggregate with a sum
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;output&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ops&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;segment_sum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;messages&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;receivers&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
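
&lt;p&gt;If you want something you can run end to end, here is a minimal sketch of the same gather-scatter steps in plain NumPy rather than JAX. The toy graph, the sizes, and the single random linear map standing in for the filter-generating network are all assumptions made for illustration. The broadcasted product computes the same thing as the einsum above.&lt;/p&gt;

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy graph (assumed for illustration): 4 nodes, 5 directed edges.
n_nodes, f, f_, s = 4, 3, 2, 5
senders = np.array([0, 1, 2, 3, 0])
receivers = np.array([1, 2, 3, 0, 2])

x = rng.normal(size=(n_nodes, f))        # node attributes
e = rng.normal(size=(len(senders), s))   # edge attributes

# Hypothetical stand-in for the filter-generating network: one linear map.
w_fgn = rng.normal(size=(s, f * f_))
kernel = (e @ w_fgn).reshape(-1, f, f_)  # one (f, f_) matrix per edge

# Gather the sender attributes and multiply each by its edge-specific kernel.
messages = (x[senders][:, :, None] * kernel).sum(axis=1)

# Scatter: sum the messages that arrive at each receiver.
output = np.zeros((n_nodes, f_))
np.add.at(output, receivers, messages)
```

&lt;p&gt;Node 2 receives two edges here (from nodes 1 and 0), so its output row is the sum of the two corresponding messages.&lt;/p&gt;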

&lt;p&gt;Once you get the hang of it, building GNNs becomes so intuitive that you’ll never want to go back to the matrix-multiplication-based implementations. 
Sometimes, though, matrix multiplication is still the better choice. But that’s a story for another day.&lt;/p&gt;

&lt;hr /&gt;

&lt;p&gt;With the first two parts of this blog series in your arsenal, you should be able to go a long way in the world of GNNs.&lt;/p&gt;

&lt;p&gt;The next and final part will take a more historical and mathematical journey in the world of GNNs. We’ll cover spectral graph theory and how we can define the operation of &lt;strong&gt;convolution&lt;/strong&gt; on graphs.&lt;/p&gt;

&lt;p&gt;I have left this for last because it is not &lt;em&gt;essential&lt;/em&gt; to understand and use GNNs in practice, although I think that understanding the historical perspective that led to the creation of modern GNNs is very important.&lt;/p&gt;

&lt;p&gt;Stay tuned.&lt;/p&gt;
</description>
        <pubDate>Fri, 12 Mar 2021 00:00:00 +0000</pubDate>
        
        <link>/posts/2021-03-12/gnn-lecture-part-2.html</link>
          
        
            <category>GNN</category>
        
            <category>lecture</category>
        
          
        
            <category>posts</category>
        
          
      </item>
    
    <item>
        <title>A practical introduction to GNNs - Part 1</title>
        <description>&lt;p&gt;&lt;em&gt;This is Part 1 of an introductory lecture on graph neural networks that I gave for the “Graph Deep Learning” course at the University of Lugano.&lt;/em&gt;&lt;/p&gt;

&lt;p&gt;&lt;em&gt;At this point in the course, the students had already seen a high-level overview of GNNs and some of their applications. My goal was to give them a practical understanding of GNNs.&lt;/em&gt;&lt;/p&gt;

&lt;p&gt;&lt;em&gt;Here I show that, starting from traditional CNNs and changing a few underlying assumptions, we can create a neural network that processes graphs.&lt;/em&gt;&lt;/p&gt;

&lt;p&gt;&lt;a href=&quot;https://danielegrattarola.github.io/files/talks/2021-03-01-USI_GDL_GNNs.pdf&quot;&gt;The full slide deck is available here&lt;/a&gt;.&lt;/p&gt;

&lt;hr /&gt;

&lt;p&gt;My goal for this lecture is to show you how Graph Neural Networks (GNNs) can be obtained as a generalization of traditional convolutional neural networks (CNNs), where instead of images we have graphs as input.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;But what does it mean that a CNN can be made more general? Why are graphs a more general version of images?&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;We know that CNNs are designed to process data that describe the world through a collection of discrete data points: time steps in a time series, pixels in an image, pixels in a video, etc.&lt;/p&gt;

&lt;p&gt;However, one aspect of images and time series that we rarely (if at all) consider explicitly is the fact that the collection of data points alone is not enough. The order in which pixels are arranged to form an image is possibly more important than the pixels themselves. &lt;br /&gt;
An image can be in color or in grayscale but, as long as the arrangement of pixels is the same, we’ll likely be able to recognize the image for what it is.&lt;/p&gt;

&lt;p&gt;We could go as far as saying that an image is only an image because its pixels are arranged in a particular structure: pixels that represent points close in space or time should also be next to each other in the collection. Change this structure, and the image loses meaning.&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;https://danielegrattarola.github.io/images/2021-03-03/presentation-4.svg&quot; width=&quot;100%&quot; style=&quot;border: solid 1px;&quot; /&gt;&lt;/p&gt;

&lt;p&gt;CNNs are designed to take this &lt;strong&gt;locality&lt;/strong&gt; into account. They are designed to transform the value of each pixel, not as a function of the whole image (like an MLP would do), but as a function of the pixel’s immediate surroundings: its neighbors.&lt;/p&gt;

&lt;p&gt;Since &lt;strong&gt;locality is a kind of relation&lt;/strong&gt; between pixels, it is natural to represent the underlying structure of an image using a graph.
And, by requiring that each pixel is related only to the few other pixels that are closest to it, our graph will be a &lt;strong&gt;regular grid&lt;/strong&gt;. Every pixel has 8 neighbors (give or take boundary conditions), and the CNN uses this fact to compute a localized transformation.&lt;/p&gt;

&lt;p&gt;You can also interpret it the other way around. The kind of processing that the CNN does means that the transformation of each pixel will only depend on the few pixels that fall under the convolutional kernel. We can say that the grid structure emerges as a consequence of the CNN’s inductive bias.&lt;/p&gt;

&lt;p&gt;In any case, the important thing to note is that the grid structure does not depend on the specific pixel values. &lt;strong&gt;We separate the values of the data points from the underlying structure that supports them.&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;https://danielegrattarola.github.io/images/2021-03-03/presentation-5.svg&quot; width=&quot;100%&quot; style=&quot;border: solid 1px;&quot; /&gt;&lt;/p&gt;

&lt;p&gt;With this perspective in mind, the question of “how to make CNNs work on graphs” becomes:&lt;/p&gt;

&lt;p&gt;&lt;em&gt;Can we create a neural network in which the structure of the data is no longer a regular grid, but an arbitrary graph that we give as input?&lt;/em&gt;&lt;/p&gt;

&lt;p&gt;In other words, since we know that data and structure are different things, can we change the structure as we please?&lt;/p&gt;

&lt;p&gt;The only thing that we require is that the CNN does the same kind of local processing as it did for the regular grid: transform each node as a function of its neighbors.&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;https://danielegrattarola.github.io/images/2021-03-03/presentation-6.svg&quot; width=&quot;100%&quot; style=&quot;border: solid 1px;&quot; /&gt;&lt;/p&gt;

&lt;p&gt;If we look at what this request entails, we immediately see some problems:&lt;/p&gt;

&lt;ol&gt;
  &lt;li&gt;
    &lt;p&gt;In the “regular grid” case, the learnable kernel of the CNN is compact and has a fixed size: one set of weights for each possible neighbor of a pixel, plus one set for the pixel itself. In other words, the kernel is supported by a smaller grid. 
We can’t do that easily for an arbitrary graph. Since nodes can have a variable number of neighbors, we also need a kernel that varies in size. Possible, but not straightforward.&lt;/p&gt;
  &lt;/li&gt;
  &lt;li&gt;
    &lt;p&gt;In the regular grids processed by CNNs, we have an implicit notion of directionality. We always know where up, down, left and right are. When we move to an arbitrary graph, we might not be able to define a direction. Direction is, in essence, a kind of attribute that we assign to the edges, but in our case we also allow graphs that have no edge attributes at all. Ask yourself: do you have an up-and-to-the-left follower on Twitter?&lt;/p&gt;
  &lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;To go from CNN to GNN we need to solve these problems.&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;https://danielegrattarola.github.io/images/2021-03-03/presentation-7.svg&quot; width=&quot;100%&quot; style=&quot;border: solid 1px;&quot; /&gt;&lt;/p&gt;

&lt;p&gt;&lt;em&gt;[I recall notation here because the students had already seen most of these things anyway, but the concept of “reference operator” gave me a nice segue into the next slide.]&lt;/em&gt;&lt;/p&gt;

&lt;p&gt;All this talking about edge attributes also made me remember that now is a good time to do a notation check. Briefly:&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;
    &lt;p&gt;We define a graph as a collection of nodes and edges.&lt;/p&gt;
  &lt;/li&gt;
  &lt;li&gt;
    &lt;p&gt;Nodes can have vector attributes, which we represent in a neatly packed matrix \(\mathbf{X} \in \mathbb{R}^{N \times F}\) (sometimes called a &lt;em&gt;graph signal&lt;/em&gt;).
Same thing for edges, with attributes \(\mathbf{e}_{ij} \in \mathbb{R}^S\) for edge i-j.&lt;/p&gt;
  &lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;Then there are the characteristic matrices of a graph:&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;
    &lt;p&gt;The adjacency matrix \(\mathbf{A}\) is binary and has a 1 in position i-j if there exists an edge from node i to node j. All other entries are 0.&lt;/p&gt;
  &lt;/li&gt;
  &lt;li&gt;
    &lt;p&gt;The degree matrix \(\mathbf{D}\) counts the number of neighbors of each node. It’s a diagonal matrix so that the degree of node i is in position i-i.&lt;/p&gt;
  &lt;/li&gt;
  &lt;li&gt;
    &lt;p&gt;The Laplacian, which we will use a lot later, is defined as \(\mathbf{L} = \mathbf{D} - \mathbf{A}\).&lt;/p&gt;
  &lt;/li&gt;
  &lt;li&gt;
    &lt;p&gt;Finally, the normalized adjacency matrix is \(\mathbf{A}_n = \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}\).&lt;/p&gt;
  &lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;Note that \(\mathbf{A}\), \(\mathbf{L}\), and \(\mathbf{A}_n\) &lt;strong&gt;share the same sparsity pattern, if you don’t count the diagonal&lt;/strong&gt;. An off-diagonal entry in position i-j can be non-zero only if edge i-j exists.&lt;/p&gt;
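
&lt;p&gt;To make the notation concrete, here is a small NumPy sketch that builds these matrices for a toy undirected graph and checks the off-diagonal sparsity pattern. The graph itself is an assumption chosen for illustration, and the code uses 0-based node indices.&lt;/p&gt;

```python
import numpy as np

# Toy undirected graph (assumed): edges 0-1, 0-2, 0-3, 2-3.
A = np.array([
    [0, 1, 1, 1],
    [1, 0, 0, 0],
    [1, 0, 0, 1],
    [1, 0, 1, 0],
], dtype=float)

deg = A.sum(axis=1)                 # number of neighbors of each node
D = np.diag(deg)                    # degree matrix
L = D - A                           # Laplacian
D_inv_sqrt = np.diag(deg ** -0.5)   # assumes there are no isolated nodes
A_n = D_inv_sqrt @ A @ D_inv_sqrt   # normalized adjacency matrix

# Compare the sparsity patterns, ignoring the diagonal.
off_diag = ~np.eye(len(A), dtype=bool)
same_pattern = (
    np.array_equal((A != 0)[off_diag], (L != 0)[off_diag])
    and np.array_equal((A != 0)[off_diag], (A_n != 0)[off_diag])
)
print(same_pattern)
```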

&lt;p&gt;Since we’re more interested in this specific property than in the actual values that are stored in the non-zero entries, let’s give it a name: we call any matrix that has the same sparsity pattern as \(\mathbf{A}\) a &lt;strong&gt;reference operator&lt;/strong&gt; (sometimes a &lt;em&gt;structure&lt;/em&gt; operator, sometimes a &lt;em&gt;graph shift&lt;/em&gt; operator, it’s not important).&lt;/p&gt;

&lt;p&gt;Also note: so far we are considering graphs with undirected edges. This means that all reference operators will be symmetric (if edge i-j exists, then edge j-i exists).&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;https://danielegrattarola.github.io/images/2021-03-03/presentation-8.svg&quot; width=&quot;100%&quot; style=&quot;border: solid 1px;&quot; /&gt;&lt;/p&gt;

&lt;p&gt;Reference operators are nice.&lt;/p&gt;

&lt;p&gt;First of all, they are operators. You multiply them by a graph signal and you get a new graph signal in return. Let’s look at the “shape” of the multiplication: N-by-N times N-by-F equals N-by-F. Checks out.&lt;/p&gt;

&lt;p&gt;But not only that. By their own definition, multiplying a reference operator by a graph signal will compute a weighted sum of each node’s neighborhood. Let’s expand the matrix multiplication from the slide above to see what happens to node 1 when we apply a reference operator.&lt;/p&gt;

&lt;p&gt;All values \(\mathbf{r}_{ij}\) that are not associated with an edge are 0, so we have:&lt;/p&gt;

\[(\mathbf{R}\mathbf{X})_1 = \mathbf{r}_{12}\cdot\mathbf{x}_2 + \mathbf{r}_{13}\cdot\mathbf{x}_3 + \mathbf{r}_{14}\cdot\mathbf{x}_4\]

&lt;p&gt;Look at that: with a simple matrix multiplication we can now do the same kind of local processing that the CNN does.&lt;/p&gt;

&lt;p&gt;Since applying a reference operator results in a simple sum-product, the result will not depend on the particular order in which we consider the nodes. As long as row \(i\) of the reference operator describes the connections of the node with attributes \(\mathbf{x}_i\), the result will be the same. We say that this kind of operation is &lt;strong&gt;equivariant to permutations of the nodes&lt;/strong&gt;. &lt;br /&gt;
This is good, because the particular order with which we consider the nodes is not important. Remember: we’re only interested in the structure – which nodes are connected to which.&lt;/p&gt;
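
&lt;p&gt;Both claims are easy to check numerically. The sketch below uses an assumed toy graph with 0-based indices (so node 0 plays the role of node 1 above), a reference operator with arbitrary random weights, and a random permutation matrix \(\mathbf{P}\) to relabel the nodes.&lt;/p&gt;

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy graph (assumed): node 0 is connected to nodes 1, 2 and 3.
A = np.array([
    [0, 1, 1, 1],
    [1, 0, 0, 0],
    [1, 0, 0, 1],
    [1, 0, 1, 0],
], dtype=float)

# A reference operator: arbitrary weights, same sparsity pattern as A.
R = A * rng.uniform(0.5, 1.5, size=A.shape)
R = (R + R.T) / 2          # keep it symmetric (undirected graph)
X = rng.normal(size=(4, 3))

# Row 0 of R @ X is a weighted sum over the neighbors of node 0.
by_hand = R[0, 1] * X[1] + R[0, 2] * X[2] + R[0, 3] * X[3]
matches = np.allclose((R @ X)[0], by_hand)

# Relabel the nodes: permute the rows of X and both axes of R.
P = np.eye(4)[rng.permutation(4)]
equivariant = np.allclose((P @ R @ P.T) @ (P @ X), P @ (R @ X))
print(matches, equivariant)
```

&lt;p&gt;Permuting the input and then applying the operator gives the same result as applying the operator and then permuting the output, which is exactly what equivariance means.&lt;/p&gt;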

&lt;p&gt;Now that we are able to aggregate information from a node’s neighborhood, we only need to solve the issue of how to create the learnable kernel and we will have a good first approximation of a CNN for graphs. Remember the two issues that we have:&lt;/p&gt;

&lt;ol&gt;
  &lt;li&gt;Neighborhoods vary in size;&lt;/li&gt;
  &lt;li&gt;We don’t know how to orient the kernel (i.e., we may not have attributes that allow us to distinguish a node’s neighbors);&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;These problems are also related to our request that the GNN must be equivariant to permutations. We cannot simply assign a different weight to each neighbor because we would need to train the GNN on all possible permutations of the nodes in order to make it equivariant.&lt;/p&gt;

&lt;p&gt;However, there is a simple solution: &lt;strong&gt;use the same set of weights for each node in the neighborhood.&lt;/strong&gt;&lt;br /&gt;
Let our weights be a matrix \(\mathbf{\Theta} \in \mathbb{R}^{F \times F&apos;}\), so that the output will have \(F&apos;\) “feature maps”.&lt;/p&gt;

&lt;p&gt;Now, we simply use \(\mathbf{\Theta}\) to transform the node attributes, then sum them over using a reference operator.&lt;/p&gt;

&lt;p&gt;Let’s check the shapes to make sure that it works out: N-by-N times N-by-F times F-by-F’ equals N-by-F’.&lt;br /&gt;
We went from graph signal to graph signal, with new node attributes that we obtain as a local, learnable, and differentiable transformation.&lt;/p&gt;

&lt;p&gt;Done! We have our first GNN: \(\mathbf{X}&apos; = \mathbf{R} \mathbf{X} \mathbf{\Theta}\).&lt;/p&gt;
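
&lt;p&gt;As a sanity check, here is that layer in a few lines of NumPy. This is only a sketch: the toy graph and the random weights are assumptions, and the plain adjacency matrix is used as reference operator for simplicity.&lt;/p&gt;

```python
import numpy as np

rng = np.random.default_rng(0)
N, F, F_out = 4, 3, 2

# Toy undirected graph (assumed): edges 0-1, 0-2, 0-3, 2-3.
A = np.array([
    [0, 1, 1, 1],
    [1, 0, 0, 0],
    [1, 0, 0, 1],
    [1, 0, 1, 0],
], dtype=float)

R = A                                  # any reference operator would do
X = rng.normal(size=(N, F))            # graph signal
Theta = rng.normal(size=(F, F_out))    # one set of weights shared by all nodes

X_new = R @ X @ Theta
print(X_new.shape)  # (4, 2)
```

&lt;p&gt;Node 1 has node 0 as its only neighbor, so its new attributes are just the transformed attributes of node 0. Note that a node’s own attributes are not used here because this choice of \(\mathbf{R}\) has a zero diagonal.&lt;/p&gt;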

&lt;p&gt;&lt;img src=&quot;https://danielegrattarola.github.io/images/2021-03-03/presentation-9.svg&quot; width=&quot;100%&quot; style=&quot;border: solid 1px;&quot; /&gt;&lt;/p&gt;

&lt;p&gt;One thing that is still missing from our relatively simple implementation is the ability to have kernels that span more than the immediate neighborhood of a node. In fact, in a CNN this is usually a hyperparameter. Also, depending on the reference operator that we use, we may or may not consider a node itself when computing its transformation: it depends on whether \(\mathbf{R}\) has a non-zero diagonal.&lt;/p&gt;

&lt;p&gt;Luckily we can generalize the idea of a bigger kernel to the graph domain: we simply process each node as a function of its neighbors up to \(K\) steps away from it.&lt;/p&gt;

&lt;p&gt;We can achieve this by considering that applying a reference operator to a graph signal has the effect of making node attributes &lt;em&gt;flow&lt;/em&gt; through the graph. 
Apply a reference operator once, and all nodes will “read” from their immediate neighbors to update themselves. Apply it again, and all nodes will read again from their neighbors, except that this time the information that they read will be whatever the neighbors computed at the previous step.&lt;/p&gt;

&lt;p&gt;In other words: if we multiply a graph signal by \(\mathbf{R}^{K}\), each node will update itself with the node attributes of nodes \(K\) steps away.&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;https://danielegrattarola.github.io/images/2021-03-03/presentation-10.svg&quot; width=&quot;100%&quot; style=&quot;border: solid 1px;&quot; /&gt;&lt;/p&gt;

&lt;p&gt;In a CNN, this would be equivalent to having a kernel shaped like an empty square.
To make the kernel full, we simply sum all “empty square” kernels up to the desired size. In our case, instead of considering \(\mathbf{R}^{K}\), we consider a polynomial of \(\mathbf{R}\) up to order \(K\).&lt;/p&gt;

&lt;p&gt;This is called a &lt;strong&gt;polynomial graph filter&lt;/strong&gt;, and we will see a different interpretation of it in Part 3 of this series.&lt;/p&gt;

&lt;p&gt;Note that this filter solves both problems that we had before, and also makes our GNN more expressive:&lt;/p&gt;

&lt;ol&gt;
  &lt;li&gt;The value of a node itself is always included in the transformation, since \(\mathbf{R}^{0} = \mathbf{I}\);&lt;/li&gt;
  &lt;li&gt;The sum of powers of \(\mathbf{R}\) up to order \(K\) will necessarily cover all neighbors in a radius of \(K\) steps;&lt;/li&gt;
  &lt;li&gt;Since we can treat neighborhoods separately, we can also have different weights \(\mathbf{\Theta}^{(k)}\) for each \(k\)-hop neighborhood. This is like having a radial filter, a function that only depends on the radius from the origin.&lt;/li&gt;
&lt;/ol&gt;
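
&lt;p&gt;A polynomial filter of this kind can be sketched in NumPy as follows. The function name &lt;code&gt;poly_filter&lt;/code&gt;, the path graph, and the random weights are assumptions made for illustration. The sketch computes \(\sum_{k=0}^{K} \mathbf{R}^{k} \mathbf{X} \mathbf{\Theta}^{(k)}\) with plain powers of \(\mathbf{R}\).&lt;/p&gt;

```python
import numpy as np

def poly_filter(R, X, thetas):
    """Sum of R^k @ X @ thetas[k] for k = 0..K, with K = len(thetas) - 1."""
    out = np.zeros((X.shape[0], thetas.shape[-1]))
    R_k = np.eye(R.shape[0])        # R^0 = I, so every node sees itself
    for theta_k in thetas:
        out += R_k @ X @ theta_k    # contribution of the k-hop neighborhood
        R_k = R_k @ R               # next power of R
    return out

rng = np.random.default_rng(0)
N, F, F_out, K = 5, 3, 2, 2

# Path graph 0-1-2-3-4 (assumed toy example).
A = np.zeros((N, N))
idx = np.arange(N - 1)
A[idx, idx + 1] = 1
A[idx + 1, idx] = 1

X = rng.normal(size=(N, F))
thetas = rng.normal(size=(K + 1, F, F_out))  # one weight matrix per hop
X_new = poly_filter(A, X, thetas)
```

&lt;p&gt;With \(K = 2\) on this path graph, changing the attributes of node 4 affects the output of nodes 2, 3 and 4, but not of nodes 0 and 1, which are more than two steps away.&lt;/p&gt;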

&lt;p&gt;&lt;img src=&quot;https://danielegrattarola.github.io/images/2021-03-03/presentation-11.svg&quot; width=&quot;100%&quot; style=&quot;border: solid 1px;&quot; /&gt;&lt;/p&gt;

&lt;p&gt;This idea of using a polynomial filter to create a GNN was first introduced in a paper by &lt;a href=&quot;https://arxiv.org/abs/1606.09375&quot;&gt;Defferrard et al.&lt;/a&gt;, which can be seen as the first scalable and practical implementation of a GNN ever proposed.&lt;/p&gt;

&lt;p&gt;In that paper they used a particular choice of polynomial, namely one for which different powers are defined in a recursive manner, called a &lt;strong&gt;Chebyshev polynomial&lt;/strong&gt;.&lt;/p&gt;

&lt;p&gt;In particular, as reference operator they use a version of the graph Laplacian that is first normalized and then rescaled so that its eigenvalues are between -1 and 1.
Then, using the recursive formulation of Chebyshev polynomials, they build a polynomial graph filter.&lt;/p&gt;

&lt;p&gt;The reason why they use these polynomials and not the simple ones we saw above is not important, for now. Let us just say: they have some desirable properties and they are fast to compute.&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;https://danielegrattarola.github.io/images/2021-03-03/presentation-12.svg&quot; width=&quot;100%&quot; style=&quot;border: solid 1px;&quot; /&gt;&lt;/p&gt;

&lt;p&gt;Just a few months after the paper by Defferrard et al. was posted on arXiv, a new paper by &lt;a href=&quot;&quot;&gt;Kipf &amp;amp; Welling&lt;/a&gt; also appeared online.&lt;/p&gt;

&lt;p&gt;In that paper, the authors looked at the Chebyshev filter proposed by Defferrard et al. and introduced a few key changes to make the layer simpler and more scalable.&lt;/p&gt;

&lt;ol&gt;
  &lt;li&gt;They changed the reference operator. Instead of the rescaled and normalized Laplacian, they assumed that \(\lambda_{max} = 2\) so that the whole formulation of the operator was simplified to \(-\mathbf{A}_n\).&lt;/li&gt;
  &lt;li&gt;They proposed to use polynomials of order 1, following the intuition that \(K\) layers of order 1 would be equivalent to 1 layer of order \(K\). In particular, they also added non-linearities between each successive layer, leading to more complex transformations of the nodes at each propagation step.&lt;/li&gt;
  &lt;li&gt;They observed that the same set of weights could be used both for a node itself and its neighbors. No need to have \(\mathbf{\Theta}^{(0)}\) and \(\mathbf{\Theta}^{(1)}\) as different weights.&lt;/li&gt;
  &lt;li&gt;After simplifying the layer down to 
\(\mathbf{X}&apos; = ( \mathbf{I} + \mathbf{A}_n) \mathbf{X} \mathbf{\Theta},\)
they observed that a more stable behavior could be obtained by instead using \(\mathbf{R} = \mathbf{D}^{-1/2} (\mathbf{I} + \mathbf{A}) \mathbf{D}^{-1/2}\) as reference operator, with the degrees in \(\mathbf{D}\) recomputed to include the added self-loops.&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;Putting this all together, we get to what is commonly known as the Graph Convolutional Network (GCN):&lt;/p&gt;

\[\mathbf{X}&apos; = \mathbf{D}^{-1/2} (\mathbf{I} + \mathbf{A}) \mathbf{D}^{-1/2} \mathbf{X} \mathbf{\Theta}\]
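
&lt;p&gt;In code, one GCN propagation step could look like the sketch below. The function name &lt;code&gt;gcn_layer&lt;/code&gt; is made up for illustration, and the degrees are computed from \(\mathbf{I} + \mathbf{A}\) (that is, counting the self-loop), which is my reading of the Kipf &amp;amp; Welling formulation. Bias and non-linearity are left out.&lt;/p&gt;

```python
import numpy as np

def gcn_layer(A, X, Theta):
    """One linear GCN propagation step (no bias, no activation)."""
    A_hat = A + np.eye(len(A))              # add self-loops: I + A
    deg = A_hat.sum(axis=1)                 # degrees, self-loop included
    D_inv_sqrt = np.diag(deg ** -0.5)
    R = D_inv_sqrt @ A_hat @ D_inv_sqrt     # reference operator
    return R @ X @ Theta

# Two nodes joined by one edge (assumed toy example).
A = np.array([[0.0, 1.0],
              [1.0, 0.0]])
X = np.array([[1.0],
              [3.0]])
Theta = np.array([[1.0]])

print(gcn_layer(A, X, Theta))
```

&lt;p&gt;With identity weights, each of the two nodes ends up with the average of its own value and its neighbor’s, so both outputs are 2.&lt;/p&gt;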

&lt;hr /&gt;

&lt;p&gt;What we have seen so far is a very simple construction that takes the general concepts behind CNNs and, by changing a few assumptions, extends them to the case in which the input is an arbitrary graph instead of a grid.&lt;/p&gt;

&lt;p&gt;This is far from the whole story, but it should give you a good starting point to learn about GNNs.&lt;/p&gt;

&lt;p&gt;In the &lt;a href=&quot;https://danielegrattarola.github.io/posts/2021-03-12/gnn-lecture-part-2.html&quot;&gt;next part of this series&lt;/a&gt; we will see:&lt;/p&gt;

&lt;ol&gt;
  &lt;li&gt;How to express what we just saw as a general algorithm that lets us describe a much richer family of operations on graphs.&lt;/li&gt;
  &lt;li&gt;How to throw edge attributes in the mix and create GNNs that can treat neighbors differently.&lt;/li&gt;
  &lt;li&gt;How to make the entries of a reference operator a learnable function.&lt;/li&gt;
  &lt;li&gt;A general recipe for a GNN that &lt;em&gt;should&lt;/em&gt; work well for many problems.&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;Stay tuned.&lt;/p&gt;
</description>
        <pubDate>Wed, 03 Mar 2021 00:00:00 +0000</pubDate>
        
        <link>/posts/2021-03-03/gnn-lecture-part-1.html</link>
          
        
            <category>GNN</category>
        
            <category>lecture</category>
        
          
        
            <category>posts</category>
        
          
      </item>
    
    <item>
        <title>Telestrations Neural Networks</title>
        <description>&lt;p&gt;&lt;img src=&quot;https://danielegrattarola.github.io/images/2020-01-21/telestrations.jpg&quot; alt=&quot;&quot; /&gt;&lt;/p&gt;

&lt;p&gt;Yesterday, it was board game day at &lt;a href=&quot;http://www.neurontobrainlaboratory.ca/&quot;&gt;the lab&lt;/a&gt; where I have been working recently. 
Everyone got together for lunch at Snakes &amp;amp; Lattes, a Torontonian board game café chain, and we spent a couple of hours laughing and chatting and, obviously, playing board games.&lt;/p&gt;

&lt;p&gt;The lab has a go-to traditional game for the occasion: &lt;a href=&quot;https://en.wikipedia.org/wiki/Telestrations&quot;&gt;Telestrations&lt;/a&gt;.
The game is inspired by the classic children’s game of &lt;a href=&quot;https://en.wikipedia.org/wiki/Chinese_whispers&quot;&gt;Chinese whispers&lt;/a&gt; (or &lt;em&gt;Telephone&lt;/em&gt;, or &lt;em&gt;Wireless phone&lt;/em&gt;, or &lt;em&gt;Gossip&lt;/em&gt;, there’s a bunch of different names in different countries) and its rules are pretty simple.&lt;/p&gt;

&lt;p&gt;Everyone gets a booklet, an erasable sharpie, and a list of random terms like “flamingo” or “pipe dream” or “treehouse”. Everyone picks a word and writes it on the first page of the booklet: that’s the secret source word.&lt;/p&gt;

&lt;p&gt;At each turn, players pass their booklet to the person on their right, and the rules are as follows:&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;When you see a word, you turn the page and you have sixty seconds to &lt;em&gt;draw&lt;/em&gt; whatever the word is;&lt;/li&gt;
  &lt;li&gt;When you see a drawing, you turn the page and you write your best guess for what is pictured.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;Players keep alternating between guessing, drawing, and passing down the booklets until every booklet has done a full round of the table and is back in the hands of the original owner. 
For extra fun, everybody gets to draw their secret source word at the very beginning.&lt;/p&gt;

&lt;p&gt;In other words, it’s a written game of Chinese whispers where every other word is drawn instead of written.&lt;/p&gt;

&lt;p&gt;There are some rules to decide who wins at the end, but the obvious source of entertainment is the complete chaos that ensues as information gets corrupted drawing after drawing. At the end of a round, not one of the original secret words ever survives.&lt;/p&gt;

&lt;p&gt;So now the obvious, rational, almost trivial question is: what happens when you use a GAN to draw, and an image classifier to guess? &lt;br /&gt;
Well, here I am to show you!&lt;/p&gt;

&lt;!--more--&gt;

&lt;h2 id=&quot;how-to-in-three-paragraphs&quot;&gt;How-to in three paragraphs&lt;/h2&gt;

&lt;p&gt;&lt;a href=&quot;https://openreview.net/forum?id=B1xsqj09Fm&quot;&gt;BigGAN&lt;/a&gt; can generate images conditioned on an ImageNet label. So if you give it label 1, it will generate goldfish, if you give it label 42, it will generate an agama, and so on.&lt;/p&gt;

&lt;p&gt;ResNet does the opposite: if you show it a goldfish, it will try to guess what it is. To make things more interesting, I added a bit of noise to the guessing procedure, so that sometimes we get a random one out of the top-5 guesses. If you think that this is unreasonable, try and play a game with real humans, I dare you.&lt;/p&gt;

&lt;p&gt;The idea now is to play the game using BigGAN to draw, and ResNet to guess: you start with a label, you have BigGAN generate an image of that label, you classify that image to get a new label, and so on.&lt;/p&gt;

&lt;h2 id=&quot;results&quot;&gt;Results&lt;/h2&gt;

&lt;p&gt;I’ll start with my favourite sequence: honeycomb to cheeseburger. 
The images below are read top-to-bottom, left-to-right. At the very top you see the source class, then the first generated image, then what that image was classified as, then the next generated image, etc.&lt;/p&gt;

&lt;p&gt;The first image is generated from class 599 of ImageNet, “honeycomb”. It looked a lot like a bagel, I guess because of that bright spot in the middle (?), so the ResNet classified it as such. From that classification, we get a couple of bagel-y looking pieces of bread, which soon become French loaves, then dough.&lt;br /&gt;
Then, that perfect-looking dough in image 6 gets classified as a wooden spoon (probably because of the extra noise that I mentioned). 
Finally, the green spot on the wooden spoon confuses ResNet into thinking it’s a cheeseburger, and we get juicy burgers until the end. That burger generation is impressive, not gonna lie.&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;https://danielegrattarola.github.io/images/2020-01-21/bagel.png&quot; alt=&quot;&quot; /&gt;&lt;/p&gt;

&lt;p&gt;Moving on: trilobite to long-horned beetle. 
The first two trilobites look really good, but then get classified as isopods after two turns (curiously, isopods and trilobites look very similar but are not that closely related, according &lt;a href=&quot;https://www.reddit.com/r/geology/comments/lt9so/how_closely_related_are_isopods_to_trilobites/&quot;&gt;to Reddit&lt;/a&gt;).
From the isopod label, we get what is clearly a marine creature (look at the background), which unfortunately gets classified as a cockroach. From there, we stay on dry land and just get more and more specialized bugs until the end.&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;https://danielegrattarola.github.io/images/2020-01-21/trilobite.png&quot; alt=&quot;&quot; /&gt;&lt;/p&gt;

&lt;p&gt;The next one is &lt;strong&gt;REALLY&lt;/strong&gt; good because it’s remarkably similar to a real game of Telestrations. It could happen. Hell, it probably happened.&lt;/p&gt;

&lt;p&gt;We start with a coffeepot. At image three, the coffeepot is a bit ambiguous and becomes a teapot. Understandable, I would probably have made that mistake myself. Then we get a proper teapot, that gets recognized as such. 
The next image, however, is half-assed by the player and it’s not clear at all what it is. The next player guesses that it’s a pitcher. The next guy tries his best but eventually, the pitcher becomes a vase.&lt;/p&gt;

&lt;p&gt;Nothing more to say, I can see this happening in real life.&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;https://danielegrattarola.github.io/images/2020-01-21/coffeepot.png&quot; alt=&quot;&quot; /&gt;&lt;/p&gt;

&lt;p&gt;Our next and last one is also a likely sequence.&lt;/p&gt;

&lt;p&gt;A volcano. Easy. We get two perfect volcano drawings. Except that the last one gets classified as a type of tent.&lt;/p&gt;

&lt;p&gt;The next player over-does it, and draws a full camping spot with caravans instead of a tent. Curiously, we still have a volcano-looking thing in the background, but that’s just a coincidence (no information from previous images or labels is preserved between turns).&lt;/p&gt;

&lt;p&gt;The camp is seen as a bee house. Next thing we know, there’s a weird-looking BigGAN human harvesting honey. 
But ResNet doesn’t care about the human and focuses on the crate in the middle, instead.&lt;/p&gt;

&lt;p&gt;We get a good-looking crate, that becomes a chest, and we stay with chests until the end.&lt;/p&gt;

&lt;p&gt;The yurt and the apiary are the only weird ones in this sequence, and the least likely to appear in a human game. I can see someone drawing a full camping spot instead of a single yurt, and I can see how one would mistake a poorly-drawn volcano for a tent, but no human would ignore the beekeeper in image 4.&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;https://danielegrattarola.github.io/images/2020-01-21/volcano.png&quot; alt=&quot;&quot; /&gt;&lt;/p&gt;

&lt;p&gt;I have generated a bunch of these sequences on my laptop, and these are just four random ones that I got. It’s really easy to get fun sequences. So here’s how I did it.&lt;/p&gt;

&lt;h2 id=&quot;code&quot;&gt;Code&lt;/h2&gt;

&lt;p&gt;First of all, I was not going to spend a single € to train anything involved in this project because, like, let’s be real…&lt;/p&gt;

&lt;p&gt;So I turned to Google and I found:&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;&lt;a href=&quot;https://github.com/huggingface/pytorch-pretrained-BigGAN&quot;&gt;A pre-trained BigGAN&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;https://pytorch.org/docs/stable/torchvision/index.html&quot;&gt;Torchvision’s pre-trained ResNet50&lt;/a&gt;&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;I usually write my stuff in TensorFlow but whatever, let’s PyTorch this one.&lt;/p&gt;

&lt;p&gt;We start with some essential imports:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;numpy&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torch&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;PIL&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Image&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;pytorch_pretrained_biggan&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;BigGAN&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;one_hot_from_int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;truncated_noise_sample&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;convert_to_images&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;spektral.utils&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;init_logging&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torchvision&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;models&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torchvision&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;transforms&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;and we define a couple of useful variables:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;iterations&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;8&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# How many players there are
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;standard_noise&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.3&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# Some random noise because people are not perfect
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;current_class&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;random&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;randint&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1000&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# The secret source word is random
&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;# Load ImageNet class list
&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;with&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;open&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;imagenet_classes.txt&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;labels&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;line&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;strip&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;line&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;readlines&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()]&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;imagenet_classes.txt&lt;/code&gt; &lt;a href=&quot;https://github.com/Lasagne/Recipes/blob/master/examples/resnet50/imagenet_classes.txt&quot;&gt;can be found online&lt;/a&gt;; it’s just a list of ImageNet class names.&lt;/p&gt;

&lt;p&gt;Now, let’s create the models that we will use. First we create the GAN:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;gan&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;BigGAN&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;from_pretrained&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;biggan-deep-256&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;gan&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;to&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;cuda&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Then, we create the ResNet50 ImageNet classifier:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;classifier&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;models&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;resnet50&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pretrained&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;classifier&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;eval&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# Do this to set the model to inference mode
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;and its image pre-processor:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;transform&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;transforms&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Compose&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;transforms&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;CenterCrop&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;224&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;transforms&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ToTensor&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(),&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;transforms&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Normalize&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;mean&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.485&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.456&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.406&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;std&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.229&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.224&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.225&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
 &lt;span class=&quot;p&quot;&gt;)])&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;We will be drawing and guessing images of 256 x 256 pixels (cropped to 224 x 224 for ResNet50). The hard-coded normalization is just something that you have to do for Torchvision models, no biggie.&lt;/p&gt;
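
&lt;p&gt;To make the pre-processing a bit less magical, here is a minimal sketch in plain Python of what the crop and the normalization compute. This is not Torchvision’s implementation: the helper names &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;normalize_pixel&lt;/code&gt; and &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;center_crop_box&lt;/code&gt; are made up for illustration, and pixel values are assumed to be already scaled to [0, 1].&lt;/p&gt;

```python
# ImageNet channel statistics used by the transform above (RGB order).
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]


def normalize_pixel(rgb):
    """Normalize one RGB pixel with values in [0, 1], channel by channel."""
    return [(value - m) / s for value, m, s in zip(rgb, MEAN, STD)]


def center_crop_box(width, height, size):
    """Return the (left, top, right, bottom) box of a centered square crop."""
    left = (width - size) // 2
    top = (height - size) // 2
    return left, top, left + size, top + size


# A pixel equal to the dataset mean maps to zero in every channel.
print(normalize_pixel(MEAN))  # [0.0, 0.0, 0.0]
# Crop box for a 256 x 256 drawing cropped to 224 x 224.
print(center_crop_box(256, 256, 224))  # (16, 16, 240, 240)
```

&lt;p&gt;So the classifier only ever sees the central 224 x 224 patch of each drawing, with 16 pixels dropped on every side.&lt;/p&gt;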

&lt;p&gt;So now we have loaded the networks. Let’s define some helper functions that will compute the main steps of the game for us:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;draw&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;label&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;truncation&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;1.&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# Create the inputs for the GAN
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;class_vector&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;one_hot_from_int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;label&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;batch_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;class_vector&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;from_numpy&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;class_vector&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;class_vector&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;class_vector&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;to&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;cuda&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;n&quot;&gt;noise_vector&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;truncated_noise_sample&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;truncation&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;truncation&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;batch_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;noise_vector&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;from_numpy&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;noise_vector&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;noise_vector&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;noise_vector&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;to&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;cuda&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# Generate image
&lt;/span&gt;    &lt;span class=&quot;k&quot;&gt;with&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;no_grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;():&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;output&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;gan&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;noise_vector&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;class_vector&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;truncation&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;output&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;output&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;to&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;cpu&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# Get a PIL image from a Torch tensor
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;img&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;convert_to_images&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;output&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;img&lt;/span&gt;
    

&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;guess&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;img&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;top&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;5&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# Pre-process image
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;img&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;transform&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;img&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# Classify image
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;classification&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;classifier&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;img&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;unsqueeze&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;_&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;indices&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sort&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;classification&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;descending&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;percentage&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;functional&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;softmax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;classification&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dim&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# Get the global ImageNet class, labels, and the predicted probabilities
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;idxs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;idx&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;idx&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;indices&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]][:&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;top&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;labs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;labels&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;idx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;idx&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;indices&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]][:&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;top&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;probs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;percentage&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;idx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;].&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;item&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;idx&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;indices&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]][:&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;top&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;
    
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;idxs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;labs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;probs&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Now we can start playing!&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;output_imgs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[]&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# Stores the drawings
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;output_labels&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[]&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# Stores the guesses
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;output_labels&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;append&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;labels&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;current_class&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# Main game loop
&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;iterations&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# Draw an image
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;img&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;draw&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;current_class&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;output_imgs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;append&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;img&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# Guess what the image is
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;idxs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;labs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;probs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;guess&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;img&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;top&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;5&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# Add noise
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;probs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;random&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;uniform&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;standard_noise&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;probs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;probs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;probs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;sum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# Re-normalize because of noise
&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# Choose from the predictions
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;choice&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;random&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;choice&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;arange&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;labs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;p&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;probs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;current_class&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;idxs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;choice&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;output_labels&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;append&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;labs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;choice&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
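
&lt;p&gt;The noisy choice inside the loop can also be looked at in isolation. The sketch below uses only the standard library instead of NumPy, and the helper name &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;noisy_choice&lt;/code&gt; and the probabilities are made up for illustration. Note that the top-5 probabilities of a 1000-way softmax do not sum to one on their own, so the re-normalization would be needed even without the noise.&lt;/p&gt;

```python
import random


def noisy_choice(probs, noise, rng):
    """Add uniform noise in [0, noise] to each probability, renormalize, sample an index."""
    noisy = [p + rng.uniform(0, noise) for p in probs]
    total = sum(noisy)
    noisy = [p / total for p in noisy]
    index = rng.choices(range(len(noisy)), weights=noisy, k=1)[0]
    return index, noisy


rng = random.Random(0)  # Fixed seed only to make the sketch repeatable
top5 = [0.90, 0.05, 0.03, 0.01, 0.01]  # Made-up top-5 probabilities
index, noisy = noisy_choice(top5, 0.3, rng)
print(index, [round(p, 3) for p in noisy])
```

&lt;p&gt;With a noise level of 0.3, even a guess that the classifier considers very unlikely gets a real chance of being picked, which is what keeps the game from getting stuck on the same class.&lt;/p&gt;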

&lt;p&gt;At the end of the game, we will have the generated drawings in &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;output_imgs&lt;/code&gt; and the guesses in &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;output_labels&lt;/code&gt;.&lt;/p&gt;

&lt;p&gt;Here, instead of copy-pasting from the cells above you can just look at &lt;a href=&quot;https://gist.github.com/danielegrattarola/8296b9fd29116443da74d0aa2519d7c3&quot;&gt;the full gist&lt;/a&gt;.&lt;/p&gt;

&lt;h2 id=&quot;conclusions&quot;&gt;Conclusions&lt;/h2&gt;
&lt;p&gt;What can I say? It’s neural networks playing Telestrations.&lt;/p&gt;

&lt;p&gt;“No new knowledge can be extracted from my telling. This confession has meant nothing.”&lt;/p&gt;

&lt;p&gt;Cheers!&lt;/p&gt;

&lt;hr /&gt;

&lt;p&gt;In case you didn’t know: agamas (label 42 of ImageNet) are extra-fucking-cool lizards.&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;https://danielegrattarola.github.io/images/2020-01-21/agama.jpg&quot; alt=&quot;&quot; /&gt;&lt;/p&gt;
</description>
        <pubDate>Tue, 21 Jan 2020 00:00:00 +0000</pubDate>
        
        <link>/posts/2020-01-21/telestrations.html</link>
          
        
            <category>AI</category>
        
            <category>random</category>
        
            <category>code</category>
        
          
        
            <category>posts</category>
        
          
      </item>
    
    <item>
        <title>Pitfalls of Graph Neural Network Evaluation 2.0</title>
        <description>&lt;p&gt;In this post, I’m going to summarize some conceptual problems that I have found when comparing different graph neural networks (GNNs) with one another.&lt;/p&gt;

&lt;p&gt;I’m going to argue that it is extremely difficult to make an objectively fair comparison between structurally different models and that the experimental comparisons found in the literature are not always sound.&lt;/p&gt;

&lt;p&gt;I will try to suggest reasonable solutions whenever possible, but the goal of this post is simply to make these issues appear on your radar and maybe spark a conversation on the matter.&lt;/p&gt;

&lt;p&gt;Some of the things that I’ll say are also addressed in the original &lt;a href=&quot;https://arxiv.org/abs/1811.05868&quot;&gt;Pitfalls of Graph Neural Network Evaluation (Shchur et al., 2018)&lt;/a&gt;, which I warmly suggest you read.&lt;/p&gt;

&lt;!--more--&gt;

&lt;h2 id=&quot;neighbourhoods&quot;&gt;Neighbourhoods&lt;/h2&gt;

&lt;p&gt;The first source of inconsistency when comparing GNNs comes from the fact that different layers are designed to take into account neighborhoods of different sizes.&lt;br /&gt;
Usually, a layer either looks at the 1-hop neighbors of each node, or it has a hyperparameter K that controls the size of the neighborhood. Some examples of popular methods (implemented in both Spektral and PyTorch Geometric) in either category:&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;1-hop: &lt;a href=&quot;https://arxiv.org/abs/1609.02907&quot;&gt;GCN&lt;/a&gt;, &lt;a href=&quot;https://arxiv.org/abs/1710.10903&quot;&gt;GAT&lt;/a&gt;, &lt;a href=&quot;https://arxiv.org/abs/1706.02216&quot;&gt;GraphSage&lt;/a&gt;, &lt;a href=&quot;https://arxiv.org/abs/1810.00826&quot;&gt;GIN&lt;/a&gt;;&lt;/li&gt;
  &lt;li&gt;K-hop: &lt;a href=&quot;https://arxiv.org/abs/1606.09375&quot;&gt;Cheby&lt;/a&gt;, &lt;a href=&quot;https://arxiv.org/abs/1901.01343&quot;&gt;ARMA&lt;/a&gt;, &lt;a href=&quot;https://arxiv.org/abs/1810.05997&quot;&gt;APPNP&lt;/a&gt;, &lt;a href=&quot;https://arxiv.org/abs/1902.07153&quot;&gt;SGC&lt;/a&gt;.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;A fair evaluation should take these differences into account and allow each GNN to look at the same neighborhoods, but at the same time, it could be argued that a layer designed to operate on larger neighborhoods is more expressive. How can we tell what is better?&lt;/p&gt;

&lt;p&gt;Let’s say we are comparing GCN with Cheby. The equivalent of a 2-layer GCN could be a 2-layer Cheby with K=1, or a 1-layer Cheby with K=2. In the GCN paper, they use a 2-layer Cheby with K=3. Should they have compared with a 6-layer GCN?&lt;/p&gt;
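
&lt;p&gt;The arithmetic behind this comparison is simple enough to write down. The sketch below assumes that a layer with hyperparameter K reaches exactly K hops and that stacking layers adds up the hops; the function name &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;receptive_field&lt;/code&gt; is made up for illustration.&lt;/p&gt;

```python
def receptive_field(num_layers, hops_per_layer=1):
    """Hops reached by a stack of layers, assuming each layer adds hops_per_layer hops."""
    return num_layers * hops_per_layer


print(receptive_field(2))     # 2-layer GCN: 2
print(receptive_field(2, 1))  # 2-layer Cheby with K=1: 2
print(receptive_field(1, 2))  # 1-layer Cheby with K=2: 2
print(receptive_field(2, 3))  # 2-layer Cheby with K=3, as in the GCN paper: 6
```

&lt;p&gt;Under this assumption, the 2-layer Cheby with K=3 sees as far as a 6-layer stack of 1-hop layers, which is where the question above comes from.&lt;/p&gt;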

&lt;p&gt;Moreover, this difference between methods may have an impact on the number of parameters, nonlinearity, and overall amount of regularization in a GNN. &lt;br /&gt;
For instance, a GCN that reaches a neighborhood of order 3 may have 3 dropout layers, while the equivalent Cheby with K=3 will have only one.  &lt;br /&gt;
Another example: an SGC architecture can reach any neighborhood with a constant number of parameters, while other methods can’t.&lt;/p&gt;

&lt;p&gt;We’re only looking at one simple issue, and it is already difficult to say how to fairly evaluate different methods. It gets worse.&lt;/p&gt;

&lt;h2 id=&quot;regularization-and-training&quot;&gt;Regularization and training&lt;/h2&gt;

&lt;p&gt;Regularization is particularly important in GNNs, because the community uses very small benchmark datasets and most GNNs tend to overfit like crazy (more on this later).
For these reasons, the performance of a GNN can vary wildly depending on how the model is regularized. This is true for all other hyperparameters in general, because things like the learning rate and batch size can be a form of implicit regularization.&lt;/p&gt;

&lt;p&gt;The literature is largely inconsistent with how regularization is applied across different papers, making it difficult to say whether the performance improvements reported for a model are due to the actual contribution or to a different regularization scheme.&lt;/p&gt;

&lt;p&gt;The following are often found in the literature:&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;High learning rates;&lt;/li&gt;
  &lt;li&gt;High L2 penalty;&lt;/li&gt;
  &lt;li&gt;Extremely high dropout rates on node features and adjacency matrix;&lt;/li&gt;
  &lt;li&gt;Low number of training epochs;&lt;/li&gt;
  &lt;li&gt;Low patience for early stopping.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;I’m going to focus on a few of these.&lt;/p&gt;

&lt;p&gt;First, I argue that setting a fixed number of training epochs is a form of alchemy that should be avoided if possible, because it’s incredibly task-specific. Letting a model train to convergence is almost always a better approach, because it’s less dependent on the initialization of the weights. If the validation performance is not indicative of the test performance and we need to stop the training without a good criterion, then something is probably wrong.&lt;/p&gt;
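
&lt;p&gt;As a rough illustration of why the stopping criterion matters, the sketch below applies patience-based early stopping to a made-up validation curve. The function name, the loss values, and the patience settings are all hypothetical; the only point is that a low patience can stop training on an early plateau.&lt;/p&gt;

```python
def early_stopping(val_losses, patience):
    """Return (best_epoch, best_loss) found by patience-based early stopping."""
    best_loss, best_epoch = float("inf"), -1
    for epoch, loss in enumerate(val_losses):
        if best_loss > loss:
            # Validation loss improved: remember it and reset the counter.
            best_loss, best_epoch = loss, epoch
        elif epoch - best_epoch >= patience:
            # No improvement for `patience` consecutive epochs: stop.
            break
    return best_epoch, best_loss


# Made-up validation curve with a plateau before the real minimum.
val_losses = [1.0, 0.8, 0.7, 0.72, 0.71, 0.73, 0.5,
              0.4, 0.41, 0.42, 0.43, 0.44, 0.45]

low_patience = early_stopping(val_losses, patience=2)
high_patience = early_stopping(val_losses, patience=5)
print(low_patience, high_patience)
```

&lt;p&gt;With a patience of 2 the run stops on the plateau around 0.7, while a patience of 5 lets it reach the later minimum at 0.4. Two papers that differ only in this setting could report very different numbers for the same model.&lt;/p&gt;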

&lt;p&gt;A second important aspect that I feel gets overlooked often is dropout. &lt;br /&gt;
In particular, when dropout is applied to the adjacency matrix it leads to big performance improvements, because the GNN is exposed to very noisy instances of the graphs at each training step and is forced to generalize well. &lt;br /&gt;
When comparing different models, if one is using dropout on the adjacency matrix then all the others should do the same. However, the common practice of comparing methods using the “same architecture from the original paper” means that some methods will be tested with dropout on A, and some without, as if the dropout is a particular characteristic of only some methods.&lt;/p&gt;
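
&lt;p&gt;For reference, this is roughly what dropout on the adjacency matrix amounts to. The sketch assumes that entries are dropped independently and that the surviving ones are rescaled as in standard dropout; the function &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;dropout_adjacency&lt;/code&gt; is a hypothetical helper and not the API of any library.&lt;/p&gt;

```python
import numpy as np


def dropout_adjacency(adj, rate, rng):
    """Randomly zero out entries of `adj` and rescale the survivors.

    Assumption: entries are dropped independently with probability `rate`
    and kept entries are divided by (1 - rate), as in standard dropout.
    """
    keep = rng.binomial(1, 1.0 - rate, size=adj.shape)
    return adj * keep / (1.0 - rate)


rng = np.random.default_rng(0)
adj = np.array([[0., 1., 1.],
                [1., 0., 1.],
                [1., 1., 0.]])

# A different noisy version of the graph is sampled at every training step.
for step in range(3):
    print(dropout_adjacency(adj, rate=0.5, rng=rng))
```

&lt;p&gt;Whether the mask is applied to the raw adjacency matrix, to a normalized version of it, or to attention coefficients may vary from paper to paper, which is one more reason to apply it consistently across the models being compared.&lt;/p&gt;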

&lt;p&gt;Finally, the remaining key factors in training are the learning rate and weight decay. 
These are often given as-is in the literature, but it is a good idea to tune them whenever possible. For what it’s worth, I can personally confirm that searching for a good learning rate, in particular, can lead to unexpected results, even for well-established methods (if the model is trained to convergence).&lt;/p&gt;

&lt;h2 id=&quot;parallel-heads&quot;&gt;Parallel heads&lt;/h2&gt;

&lt;p&gt;&lt;em&gt;Heads&lt;/em&gt; are parallel computational units that perform the same calculation with different weights and then merge the results to produce the output. To give a sense of the problems that one may encounter when comparing methods that use heads, I will focus on two methods: GAT and ARMA.&lt;/p&gt;

&lt;p&gt;Having parallel attention heads is fairly common in NLP literature, from where the very concept of attention comes, and therefore it was natural to do the same in GAT.&lt;/p&gt;

&lt;p&gt;In ARMA, using parallel &lt;em&gt;stacks&lt;/em&gt; is theoretically motivated by the fact that ARMA filters of order H can be computed by summing H ARMA filters of order 1. While similar in practice to the heads in GAT, in this case having parallel heads is key to the implementation of this particular graph filter.&lt;/p&gt;

&lt;p&gt;Because of these fundamental semantic differences, it is impossible to say whether a comparison between GAT with H heads and an ARMA layer of order H is fair.&lt;/p&gt;

&lt;p&gt;Extending to the other models as well, it is not guaranteed that having parallel heads would necessarily lead to any practical improvements for a given model. Some methods can, in fact, benefit from a simpler architecture. 
It is therefore difficult to say whether a comparison between monolithic and parallel architectures is fair.&lt;/p&gt;

&lt;h2 id=&quot;datasets&quot;&gt;Datasets&lt;/h2&gt;

&lt;p&gt;Finally, I’m going to spend a few words on datasets, because there is no chance of having a fair evaluation if the datasets on which we test our models are not good. And in truth, the benchmark datasets that we use for evaluating GNNs are not that good.&lt;/p&gt;

&lt;p&gt;Cora, CiteSeer, PubMed, and the Dortmund benchmark datasets for graph kernels: these are, collectively, the Iris dataset of GNNs, and should be treated carefully. While a model should work on these in order to be considered usable, they cannot be the only criterion to run a fair evaluation.&lt;/p&gt;

&lt;p&gt;Recently, the community has moved towards a more sensible use of the datasets (ok, maybe I was exaggerating a bit about Iris), thanks to papers like &lt;a href=&quot;https://arxiv.org/abs/1811.05868&quot;&gt;this&lt;/a&gt; and &lt;a href=&quot;https://arxiv.org/abs/1910.12091&quot;&gt;this&lt;/a&gt;. However, many experiments in the literature still had to be repeated hundreds of times in order to give significant results, and that is bad for three reasons: time, money, and the environment, in no particular order.&lt;br /&gt;
Especially if running a grid search of hyperparameters, it just doesn’t make sense to be using datasets that require that much computation to give reliable outcomes, more so if we consider that these are supposed to be &lt;em&gt;easy&lt;/em&gt; datasets.&lt;/p&gt;

&lt;p&gt;Personally, I find that there are better alternatives out there, which however are not often considered. For node classification, the GraphSage datasets (PPI and Reddit) are significantly better benchmarks than the citation networks (although they’re inductive tasks). 
For graph-level learning, QM9 has 134k small graphs, of variable order, and will lead to minuscule uncertainty about the results after a few runs. I realize that it is a dataset for regression, but it still is a better alternative to PROTEINS. 
For classification, Filippo Bianchi, with whom I’ve recently worked a lot, released a dataset that simply cannot be classified without using a GNN. You can find it &lt;a href=&quot;https://github.com/FilippoMB/Benchmark_dataset_for_graph_classification&quot;&gt;here&lt;/a&gt;.&lt;/p&gt;

&lt;p&gt;I will admit that I am as guilty as the next person when it comes to using the “bad” datasets mentioned above. One reason is that it is easy to not move away from what everybody else is doing. Another reason is that reviewers outright ask for them if you don’t include them, caring little for anything else.&lt;/p&gt;

&lt;p&gt;I think we can do better, as a community.&lt;/p&gt;

&lt;h2 id=&quot;in-conclusion&quot;&gt;In conclusion&lt;/h2&gt;

&lt;p&gt;I started thinking seriously about these issues as I was preparing a paper that required me to compare several models for the experiments. 
I am not sure whether the few solutions that I have outlined here are definitive, or even correct, but I feel that this is a conversation that needs to be had in the field of GNNs.&lt;/p&gt;

&lt;p&gt;Many of the comparisons that are found in the wild do not take any of this stuff into account, and I think that this may ultimately slow the progress of GNN research and its propagation to other fields of science.&lt;/p&gt;

&lt;p&gt;If you want to continue this conversation, or if you have any ideas that could complement this post, shoot me an email or look for me on &lt;a href=&quot;https://twitter.com/riceasphait&quot;&gt;Twitter&lt;/a&gt;.&lt;/p&gt;

&lt;p&gt;Cheers!&lt;/p&gt;
</description>
        <pubDate>Fri, 13 Dec 2019 00:00:00 +0000</pubDate>
        
        <link>/posts/2019-12-13/pitfalls.html</link>
          
        
            <category>AI</category>
        
            <category>GNN</category>
        
          
        
            <category>posts</category>
        
          
      </item>
    
    <item>
        <title>Implementing a Network-based Model of Epilepsy with Numpy and Numba</title>
        <description>&lt;p&gt;&lt;img src=&quot;https://danielegrattarola.github.io/images/2019-10-03/2_nodes_complex_plane.png&quot; alt=&quot;&quot; class=&quot;full-width&quot; /&gt;&lt;/p&gt;

&lt;p&gt;Mathematically modeling how epilepsy acts on the brain is one of the major topics of research in neuroscience. 
Recently I came across &lt;a href=&quot;https://mathematical-neuroscience.springeropen.com/articles/10.1186/2190-8567-2-1&quot;&gt;this paper&lt;/a&gt; by Oscar Benjamin et al., which I thought would be cool to implement and experiment with.&lt;/p&gt;

&lt;p&gt;The idea behind the paper is simple enough. First, they formulate a mathematical model of how a seizure might happen in a single region of the brain. Then, they expand this model to consider the interplay between different areas of the brain, effectively modeling it as a network.&lt;/p&gt;

&lt;!--more--&gt;

&lt;h2 id=&quot;single-system&quot;&gt;Single system&lt;/h2&gt;

&lt;p&gt;We start from a complex dynamical system defined as follows:&lt;/p&gt;

\[\dot{z} = f(z) = (\lambda - 1 + i \omega)z + 2z|z|^2 - z|z|^4\]

&lt;p&gt;where \( z \in \mathbb{C} \) and \(\lambda\) controls the possible attractors of the system. 
For \( 0 &amp;lt; \lambda &amp;lt; 1 \), the system has two stable attractors: one fixed point and one attractor that oscillates with an angular velocity of \(\omega\) rad/s.&lt;br /&gt;
We can consider the fixed point as a simplification of the brain in its resting state, while the oscillating attractor is taken to be the &lt;em&gt;ictal&lt;/em&gt; state (i.e., when the brain is having a seizure).&lt;/p&gt;

&lt;p&gt;We can also consider a &lt;em&gt;noise-driven&lt;/em&gt; version of the system:&lt;/p&gt;

\[dz(t) = f(z)\,dt + \alpha\,dW(t)\]

&lt;p&gt;where \( W(t) \) is a Wiener process rescaled by a factor \( \alpha \).&lt;br /&gt;
A Wiener process \( W(t)_{t\ge0} \), sometimes called &lt;em&gt;Brownian motion&lt;/em&gt;, is a stochastic process with the following properties:&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;\(W(0) = 0\);&lt;/li&gt;
  &lt;li&gt;the increments between two consecutive observations are normally distributed with a variance equal to the time between the observations:&lt;/li&gt;
&lt;/ul&gt;

\[W(t + \tau) - W(t) \sim \mathcal{N}(0, \tau).\]

&lt;p&gt;In the noise-driven version of the system, it is guaranteed that the system will eventually &lt;em&gt;escape&lt;/em&gt; any region of phase space, moving from one attractor to the other.&lt;/p&gt;

&lt;p&gt;In short, we have a system that due to external, unpredictable inputs (the noise), will randomly switch from a state of rest to a state of oscillation, which we consider as a seizure.&lt;/p&gt;

&lt;p&gt;The two figures below show an example of the system starting from the fixed point and then moving to the oscillating attractor. 
Since the system is complex, we can observe its dynamics in phase space:&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;https://danielegrattarola.github.io/images/2019-10-03/1_nodes_complex_plane.png&quot; alt=&quot;&quot; class=&quot;centered&quot; /&gt;&lt;/p&gt;

&lt;p&gt;Or we can observe the real part of \( z(t) \) as if we were reading an EEG of brain activity:&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;https://danielegrattarola.github.io/images/2019-10-03/1_nodes_re_v_time.png&quot; alt=&quot;&quot; class=&quot;centered&quot; /&gt;&lt;/p&gt;

&lt;p&gt;See how the change of attractor almost looks like an epileptic seizure?&lt;/p&gt;

&lt;h2 id=&quot;network-model&quot;&gt;Network model&lt;/h2&gt;

&lt;p&gt;While this simple model of seizure initiation is interesting on its own, we can also take our modeling a step further and explicitly represent the connections between different areas of the brain (or sub-systems, if you will) and how they might affect the propagation of seizures from one area to the other.&lt;/p&gt;

&lt;p&gt;We do this by defining a connectivity matrix \( A \) where \( A_{ij} = 1 \) if sub-system \( i \) has a direct influence on sub-system \( j \), and \( A_{ij} = 0 \) otherwise. In practice, we also normalize the matrix by dividing each row element-wise by the product of the square roots of the node’s out-degree and in-degree.&lt;/p&gt;

&lt;p&gt;Starting from the system described above, the dynamics of one node in the networked system are described by:&lt;/p&gt;

\[dz_{i}(t) = \big( f(z_i) + \beta \sum\limits_{j \ne i} A_{ji} (z_j - z_i) \big)\,dt + \alpha\,dW_{i}(t)\]

&lt;p&gt;If we look at the individual nodes, their behavior may not seem different than what we had with the single sub-system, but in reality, the attractors of these networked systems are determined by the connectivity \( A \) and the coupling strength \( \beta \).&lt;/p&gt;
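
&lt;p&gt;For simulation purposes, one common way to discretize a stochastic differential equation of this form is the Euler-Maruyama scheme. The update below is a sketch under that assumption; the time step \( \Delta t \) and the standard normal samples \( \xi_u, \xi_v \) are notation introduced here for illustration:&lt;/p&gt;

\[z_i(t + \Delta t) \approx z_i(t) + \Big( f(z_i(t)) + \beta \sum\limits_{j \ne i} A_{ji} \big(z_j(t) - z_i(t)\big) \Big)\,\Delta t + \alpha\,\Delta W_i(t)\]

&lt;p&gt;Here \( \Delta W_i(t) = \sqrt{\Delta t}\,(\xi_u + i\,\xi_v) \) with \( \xi_u, \xi_v \sim \mathcal{N}(0, 1) \) independent, which treats the complex noise as two real Wiener processes whose increments have variance \( \Delta t \), in line with the property stated earlier.&lt;/p&gt;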

&lt;p&gt;&lt;img src=&quot;https://danielegrattarola.github.io/images/2019-10-03/4_graph.png&quot; alt=&quot;&quot; class=&quot;centered&quot; /&gt;&lt;/p&gt;

&lt;p&gt;Here’s what the networked system of 4 nodes pictured above looks like in phase space:&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;https://danielegrattarola.github.io/images/2019-10-03/4_nodes_complex_plane.png&quot; alt=&quot;&quot; class=&quot;centered&quot; /&gt;&lt;/p&gt;

&lt;p&gt;And again we can also look at the real part of each node:&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;https://danielegrattarola.github.io/images/2019-10-03/4_nodes_re_v_time.png&quot; alt=&quot;&quot; class=&quot;centered&quot; /&gt;&lt;/p&gt;

&lt;p&gt;If you want to have more details on how to control the different attractors of the system, I suggest you look at the &lt;a href=&quot;https://mathematical-neuroscience.springeropen.com/articles/10.1186/2190-8567-2-1&quot;&gt;original paper&lt;/a&gt;. They analyze in depth the attractors and &lt;em&gt;escape times&lt;/em&gt; of all possible 2-nodes and 3-nodes networks, as well as giving an overview of higher-order networks.&lt;/p&gt;

&lt;h2 id=&quot;implementing-the-system-with-numpy-and-numba&quot;&gt;Implementing the system with Numpy and Numba&lt;/h2&gt;

&lt;p&gt;Now that we have the math sorted out, let’s look at how to translate this system into Numpy.&lt;/p&gt;

&lt;p&gt;Since the system is so precisely defined, we only need to convert the mathematical formulation into code. In short, we will need:&lt;/p&gt;

&lt;ol&gt;
  &lt;li&gt;The core functions to compute the complex dynamical system;&lt;/li&gt;
  &lt;li&gt;The main loop to compute the evolution of the system starting from an initial condition.&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;While developing this, I quickly realized that my original, kinda straightforward implementation was painfully slow and that it would have required some optimization to be usable.&lt;/p&gt;

&lt;p&gt;This was the perfect occasion to use &lt;a href=&quot;http://numba.pydata.org/&quot;&gt;Numba&lt;/a&gt;, a JIT compiler for Python that claims to yield speedups of up to two orders of magnitude.&lt;br /&gt;
Numba can JIT compile a large subset of pure Python code, and natively supports a vast number of Numpy operations as well. 
The juicy part of Numba consists of compiling functions in &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;nopython&lt;/code&gt; mode, meaning that the code will run without ever using the Python interpreter. 
To achieve this, it is sufficient to decorate your functions with the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;@njit&lt;/code&gt; decorator and then simply run your script as usual.&lt;/p&gt;

&lt;h2 id=&quot;code&quot;&gt;Code&lt;/h2&gt;

&lt;p&gt;At the very start, let’s deal with imports and define a couple of helper functions that we are going to use only once:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;numpy&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;numba&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;njit&lt;/span&gt;

&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;degree_power&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;adj&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;pow&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;s&quot;&gt;&quot;&quot;&quot;
    Computes D^{p} from the given adjacency matrix.

    :param adj: rank 2 array.
    :param pow: exponent to which the degree matrix is raised.
    :return: the exponentiated degree matrix.
    &quot;&quot;&quot;&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;degrees&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;power&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;adj&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;sum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;pow&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;reshape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;degrees&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;isinf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;degrees&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;D&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;diag&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;degrees&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;D&lt;/span&gt;


&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;normalized_adjacency&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;adj&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;s&quot;&gt;&quot;&quot;&quot;
    Normalizes the given adjacency matrix using the degree matrix as
    D^{-1/2}AD^{-1/2} (symmetric normalization).

    :param adj: rank 2 array.
    :return: the normalized adjacency matrix.
    &quot;&quot;&quot;&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;normalized_D&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;degree_power&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;adj&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.5&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;output&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;normalized_D&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;dot&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;adj&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;dot&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;normalized_D&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;output&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;The code for these functions was copy-pasted from &lt;a href=&quot;https://danielegrattarola.github.io/spektral/&quot;&gt;Spektral&lt;/a&gt; and slightly adapted so that we don’t need to import the entire library just for two functions. Note that there’s no need to JIT compile these two functions because they will run only once, and in fact, it is not guaranteed that compiling them will be less expensive than simply executing them with Python, especially since both functions are already heavily Numpy-based and should run at C-like speed.&lt;/p&gt;

&lt;p&gt;Moving on to the implementation of the actual system, let’s first define the fixed hyper-parameters of the model:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;omega&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;20&lt;/span&gt;               &lt;span class=&quot;c1&quot;&gt;# Frequency of oscillations in rad/s
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;alpha&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.2&lt;/span&gt;              &lt;span class=&quot;c1&quot;&gt;# Intensity of the noise
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;lamb&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.5&lt;/span&gt;               &lt;span class=&quot;c1&quot;&gt;# Controls the possible attractors of each node
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;beta&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.1&lt;/span&gt;               &lt;span class=&quot;c1&quot;&gt;# Coupling strength b/w nodes
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;N&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;4&lt;/span&gt;                    &lt;span class=&quot;c1&quot;&gt;# Number of nodes in the system
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;seconds_to_generate&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# Number of seconds to evolve the system for
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;dt&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.0001&lt;/span&gt;              &lt;span class=&quot;c1&quot;&gt;# Time interval between consecutive states
&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;# Random connectivity matrix
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;A&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;random&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;randint&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;N&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;N&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;fill_diagonal&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;A&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;A_norm&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;normalized_adjacency&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;A&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;astype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;complex128&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;The core of the dynamical system is the update function \( f(z) \), which in code looks like this:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;o&quot;&gt;@&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;njit&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;z&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lamb&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;omega&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;s&quot;&gt;&quot;&quot;&quot;The deterministic update function of each node.

    :param z: complex, the current state.
    :param lamb: float, hyper-parameter to control the attractors of each node.
    :param omega: float, frequency of oscillations in rad/s.
    &quot;&quot;&quot;&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;lamb&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;complex&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;omega&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;z&lt;/span&gt;
            &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;z&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;abs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;z&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;**&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
            &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;z&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;abs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;z&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;**&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;4&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

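&lt;p&gt;There’s not much to say here, except that using &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;complex&lt;/code&gt; instead of &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;np.complex&lt;/code&gt; seems to be slightly faster (157 ns vs. 178 ns), although the performance impact on the overall function is clearly negligible. Note that &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;np.complex&lt;/code&gt; was only an alias for the built-in type; it was deprecated in NumPy 1.20 and removed in NumPy 1.24, so on recent versions the built-in is the only option anyway.&lt;/p&gt;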

&lt;p&gt;To compute the noise-driven system, we need to define the increment function of a complex Wiener process. We can start by implementing the increment function of a simple Wiener process, first:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;o&quot;&gt;@&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;njit&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;delta_wiener&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;s&quot;&gt;&quot;&quot;&quot;Returns the random delta between two consecutive steps of a Wiener
    process (Brownian motion).

    :param size: tuple, desired shape of the output array.
    :param dt: float, time increment in seconds.
    :return: numpy array with shape &apos;size&apos;.
    &quot;&quot;&quot;&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sqrt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;dt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;random&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;randn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;At the time of writing this, Numba &lt;a href=&quot;https://numba.pydata.org/numba-doc/dev/reference/numpysupported.html#distributions&quot;&gt;does not support&lt;/a&gt; the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;size&lt;/code&gt; argument in &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;np.random.normal&lt;/code&gt; but it does support &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;np.random.randn&lt;/code&gt;. Instead of setting the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;scale&lt;/code&gt; parameter explicitly, we simply multiply the sampled values by the scale.&lt;br /&gt;
Since we are using the scale, and not the variance, we have to take the square root of the time increment &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;dt&lt;/code&gt;.&lt;/p&gt;
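
&lt;p&gt;As a quick sanity check of the square-root scaling, the sketch below samples many increments and compares their empirical variance with &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;dt&lt;/code&gt;. It runs in plain NumPy without Numba, and the sample size, the seed, and the variable names are assumptions made for illustration only:&lt;/p&gt;

```python
import numpy as np

# Hypothetical check (not part of the original simulation code): increments
# of a Wiener process over a time step dt should have mean 0 and variance dt.
dt = 1e-4            # time increment, illustrative value
n_samples = 200_000  # illustrative sample size

rng = np.random.default_rng(0)  # seeded only to make the check reproducible

# Scale standard normal samples by sqrt(dt), i.e. by the standard deviation.
increments = np.sqrt(dt) * rng.standard_normal(n_samples)

sample_mean = increments.mean()
sample_var = increments.var()

print("target variance:", dt)
print("sample variance:", sample_var)
```

&lt;p&gt;Multiplying by &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;dt&lt;/code&gt; instead of its square root would give a variance of &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;dt ** 2&lt;/code&gt;, which is far too small for small time steps.&lt;/p&gt;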

&lt;p&gt;Finally, we can compute the increment of a complex Wiener process as \( U(t) + jV(t) \), where both \( U \) and \( V \) are simple Wiener processes:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;o&quot;&gt;@&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;njit&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;complex_delta_wiener&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;s&quot;&gt;&quot;&quot;&quot;Returns the random delta between two consecutive steps of a complex
    Wiener process (Brownian motion). The process is calculated as u(t) + jv(t)
    where u and v are simple Wiener processes.

    :param size: tuple, the desired shape of the output array.
    :param dt: float, time increment in seconds.
    :return: numpy array of np.complex128 with shape &apos;size&apos;.
    &quot;&quot;&quot;&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;u&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;delta_wiener&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;v&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;delta_wiener&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;u&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt;  &lt;span class=&quot;n&quot;&gt;v&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;1j&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Now that we have all the necessary components to define the noise-driven system, let’s implement the main step function:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;o&quot;&gt;@&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;njit&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;step&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;z&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;s&quot;&gt;&quot;&quot;&quot;
    Compute one time step of the system, s.t. z[t+1] = z[t] + step(z[t]).

    :param z: numpy array of np.complex128, the current state.
    :return: numpy array of np.complex128.
    &quot;&quot;&quot;&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# Matrix with pairwise differences of nodes
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;delta_z&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;z&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;reshape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;z&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;reshape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# Compute diffusive coupling
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;diffusive_coupling&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;diag&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;A_norm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;T&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;dot&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;delta_z&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# Compute change in state
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;update_from_self&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;z&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lamb&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;lamb&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;omega&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;omega&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;update_from_others&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;beta&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;diffusive_coupling&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;noise&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;alpha&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;complex_delta_wiener&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;z&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;dz&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;update_from_self&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;update_from_others&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dt&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;noise&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dz&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Originally, I had implemented the following line&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;delta_z&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;z&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;reshape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;z&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;reshape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;as&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;delta_z&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;z&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[...,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;z&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;...]&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;but Numba does not support adding new axes with &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;None&lt;/code&gt; or &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;np.newaxis&lt;/code&gt;.&lt;/p&gt;
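
&lt;p&gt;The two forms give the same matrix of pairwise differences, which can be checked in plain NumPy (outside of any JIT-compiled function). The values in &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;z&lt;/code&gt; below are made up for the example:&lt;/p&gt;

```python
import numpy as np

# Toy state vector with three complex-valued nodes (illustrative values).
z = np.array([1 + 2j, 3 - 1j, -2 + 0.5j])

# Column vector minus row vector: entry (i, j) holds z[i] - z[j].
delta_reshape = z.reshape(-1, 1) - z.reshape(1, -1)

# Same result, using None to add the new axes (plain NumPy only).
delta_newaxis = z[..., None] - z[None, ...]

print(delta_reshape.shape)
print(np.array_equal(delta_reshape, delta_newaxis))
```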

&lt;p&gt;Also, when computing &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;diffusive_coupling&lt;/code&gt;, a more efficient way of doing&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;diag&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;A&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;T&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;dot&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;B&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;would have been&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;einsum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;ij,ij-&amp;gt;j&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;A&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;B&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;for reasons which I still fail to understand (3.48 µs vs. 2.57 µs, when &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;A&lt;/code&gt; and &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;B&lt;/code&gt; are 3 by 3 float matrices). However, Numba does not support &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;np.einsum&lt;/code&gt;.&lt;/p&gt;
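
&lt;p&gt;To see that the two expressions compute the same thing, note that entry \(j\) of the diagonal of \(A^T B\) is \(\sum_i A_{ij} B_{ij}\). The sketch below checks this on random matrices in plain NumPy; the third form, an element-wise product followed by a column sum, is only added as an illustration and its speed relative to the other two is not measured here:&lt;/p&gt;

```python
import numpy as np

rng = np.random.default_rng(0)  # seeded for reproducibility
A = rng.standard_normal((3, 3))
B = rng.standard_normal((3, 3))

# Full matrix product, then keep only the diagonal.
via_diag = np.diag(A.T.dot(B))

# Einstein summation: sum over i of A[i, j] * B[i, j] for each column j.
via_einsum = np.einsum('ij,ij->j', A, B)

# Element-wise product followed by a sum over rows (illustrative alternative).
via_sum = (A * B).sum(axis=0)

print(np.allclose(via_diag, via_einsum))
print(np.allclose(via_diag, via_sum))
```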

&lt;p&gt;Finally, we can implement the main loop function that starts from a given initial state &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;z0&lt;/code&gt; and computes &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;steps&lt;/code&gt; number of updates at time intervals of &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;dt&lt;/code&gt;.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;o&quot;&gt;@&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;njit&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;evolve_system&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;z0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;steps&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;s&quot;&gt;&quot;&quot;&quot;
    Evolve the system starting from the given initial state (z0) for a given
    number of time steps (steps).

    :param z0: numpy array of np.complex128, the initial state.
    :param steps: int, number of steps to evolve the system for.
    :return: list, the sequence of states.
    &quot;&quot;&quot;&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;steps_in_percent&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;steps&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;100&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;z&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;z0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;steps&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;not&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;%&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;steps_in_percent&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;steps_in_percent&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&apos;%&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;dz&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;step&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;z&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;z&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;append&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;z&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dz&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;z&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;I had originally wrapped the loop in a &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;tqdm&lt;/code&gt; progress bar, but an old-fashioned &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;if&lt;/code&gt; and &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;print&lt;/code&gt; can reduce the overhead by 50% (2.29s vs. 1.23s, tested on a simple &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;for&lt;/code&gt; loop with 1e7 iterations). Pre-computing &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;steps_in_percent&lt;/code&gt; also reduces the overhead by 30% compared to computing it every time.&lt;br /&gt;
(You’ll notice that at some point it just became a matter of optimizing every possible aspect of this :D)&lt;/p&gt;

&lt;p&gt;The only thing left to do is to evolve the system starting from a given initial state:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;z0&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;zeros&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;N&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;astype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;complex128&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# Starting conditions
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;steps&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;seconds_to_generate&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;   &lt;span class=&quot;c1&quot;&gt;# Number of steps to generate
&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;timesteps&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;evolve_system&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;z0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;steps&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;timesteps&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;timesteps&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;You can now run any analysis on &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;timesteps&lt;/code&gt;, which will be a Numpy array of &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;np.complex128&lt;/code&gt;. Note also how we had to cast the initial conditions &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;z0&lt;/code&gt; to this &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;dtype&lt;/code&gt;, in order to have strict typing in the JIT-compiled code.&lt;/p&gt;

&lt;p&gt;&lt;a href=&quot;https://gist.github.com/danielegrattarola/c663346b529e758f0224c8313818ad77&quot;&gt;I published the full code as a Gist, including the code I used to make the plots.&lt;/a&gt;&lt;/p&gt;

&lt;h2 id=&quot;general-notes-on-performance&quot;&gt;General notes on performance&lt;/h2&gt;

&lt;p&gt;My original implementation was based on a &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Simulator&lt;/code&gt; class that implemented all the same methods in a compact abstraction:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;Simulator&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;object&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;N&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;A&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dt&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;1e-4&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;omega&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;20&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;alpha&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.05&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lamb&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.5&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;beta&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;...&lt;/span&gt;

    &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;staticmethod&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;z&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lamb&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;omega&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;...&lt;/span&gt;

    &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;staticmethod&lt;/span&gt;    
    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;delta_wiener&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;...&lt;/span&gt;

    &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;staticmethod&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;complex_delta_wiener&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;...&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;step&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;z&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;...&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;evolve_system&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;z0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;steps&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;...&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;There were some issues with this implementation, the biggest one being that it is much messier to JIT compile an entire class with Numba (the substance of the code did not change much, and I’ve explicitly highlighted all implementation changes above).&lt;/p&gt;

&lt;p&gt;Having moved to a more functional style feels cleaner and it honestly looks more elegant (opinions, I know). Crucially, it also allowed me to optimize each function to work flawlessly with Numba.&lt;/p&gt;

&lt;p&gt;After optimizing all that was optimizable, I tested the old code against the new one and the speedup was about 31x, going from ~8k iterations/s to ~250k iterations/s.&lt;/p&gt;

&lt;p&gt;Most of the improvement came from Numba and removing the overhead of Python’s interpreter, but it must be said that the true core of the system is dealt with by Numpy. In fact, as we increase the number of nodes the bottleneck becomes the matrix multiplication in Numpy, eventually leading to virtually no performance difference between using Numba or not (verified for &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;N=1000&lt;/code&gt; - the 31x speedup was for &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;N=2&lt;/code&gt;).&lt;/p&gt;

&lt;p&gt;&lt;br /&gt;
I hope that you enjoyed this post and hopefully learned something new, be it about models of the epileptic brain or Python optimization.&lt;/p&gt;

&lt;p&gt;Cheers!&lt;/p&gt;
</description>
        <pubDate>Thu, 03 Oct 2019 00:00:00 +0000</pubDate>
        
        <link>/posts/2019-10-03/epilepsy-model.html</link>
          
        
            <category>tutorial</category>
        
            <category>code</category>
        
            <category>epilepsy</category>
        
          
        
            <category>posts</category>
        
          
      </item>
    
    <item>
        <title>MinCUT Pooling in Graph Neural Networks</title>
        <description>&lt;p&gt;&lt;img src=&quot;https://danielegrattarola.github.io/images/2019-07-25/horses.png&quot; alt=&quot;Embeddings&quot; class=&quot;full-width&quot; /&gt;&lt;/p&gt;

&lt;p&gt;In &lt;a href=&quot;https://arxiv.org/abs/1907.00481&quot;&gt;our latest paper&lt;/a&gt;, we presented a new pooling method for GNNs, called &lt;strong&gt;MinCutPool&lt;/strong&gt;, which has a lot of desirable properties as far as pooling goes:&lt;/p&gt;

&lt;ol&gt;
  &lt;li&gt;It’s based on well-understood theoretical techniques for node clustering;&lt;/li&gt;
  &lt;li&gt;It’s fully differentiable and learnable with gradient descent;&lt;/li&gt;
  &lt;li&gt;It depends directly on the task-specific loss on which the GNN is being trained, but …&lt;/li&gt;
  &lt;li&gt;It can be trained on its own without a task-specific loss if needed;&lt;/li&gt;
  &lt;li&gt;It’s fast.&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;The method is based on the minCUT optimization problem, which consists of finding a cut on a weighted graph in such a way that the overall weight of the cut is minimized. We considered a continuous relaxation of the minCUT problem and implemented it as a neural network layer to provide a sound pooling method for GNNs.&lt;/p&gt;

&lt;p&gt;In this post, I’ll describe the working principles of minCUT pooling and show some applications of the layer.&lt;/p&gt;

&lt;!--more--&gt;

&lt;h2 id=&quot;background&quot;&gt;Background&lt;/h2&gt;

&lt;p&gt;&lt;img src=&quot;https://danielegrattarola.github.io/images/2019-07-25/mincut_problem.png&quot; alt=&quot;Embeddings&quot; /&gt;&lt;/p&gt;

&lt;p&gt;The &lt;a href=&quot;https://en.wikipedia.org/wiki/Minimum_k-cut&quot;&gt;K-way normalized minCUT&lt;/a&gt; is an optimization problem to find K clusters on a graph by minimizing the overall inter-cluster edge weight, relative to the weight of the edges within each cluster. This is equivalent to solving:&lt;/p&gt;

\[\text{maximize} \;\; \frac{1}{K} \sum_{k=1}^K \frac{\sum_{i,j \in \mathcal{V}_k} \mathcal{E}_{i,j} }{\sum_{i \in \mathcal{V}_k, j \in \mathcal{V} \backslash \mathcal{V}_k} \mathcal{E}_{i,j}},\]

&lt;p&gt;where \(\mathcal{V}\) is the set of nodes, \(\mathcal{V}_k\) is the \(k\)-th cluster of nodes, and \(\mathcal{E}_{i,j}\) indicates a weighted edge between two nodes.&lt;/p&gt;

&lt;p&gt;If we define a &lt;strong&gt;cluster assignment matrix&lt;/strong&gt; \(C \in \{0,1\}^{N \times K}\), which maps each of the \(N\) nodes to one of the \(K\) clusters, the problem can also be re-written as:&lt;/p&gt;

\[\text{maximize} \;\;  \frac{1}{K} \sum_{k=1}^K \frac{C_k^T A C_k}{C_k^T D C_k}\]

&lt;p&gt;where \(A\) is the adjacency matrix of the graph, and \(D\) is the diagonal degree matrix.&lt;/p&gt;
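
&lt;p&gt;To get a feeling for the matrix form of the objective, here is a minimal NumPy sketch that evaluates it on a toy graph made of two disconnected triangles. The helper &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;mincut_objective&lt;/code&gt; and the two assignment matrices are hypothetical names and values chosen for this example:&lt;/p&gt;

```python
import numpy as np

def mincut_objective(A, C):
    """Evaluate (1 / K) * sum_k (C_k^T A C_k) / (C_k^T D C_k).

    A is the (N, N) adjacency matrix and C the (N, K) hard assignment
    matrix; every cluster is assumed to contain at least one node.
    """
    D = np.diag(A.sum(axis=1))  # diagonal degree matrix
    K = C.shape[1]
    total = 0.0
    for k in range(K):
        c_k = C[:, k]
        total += (c_k @ A @ c_k) / (c_k @ D @ c_k)
    return total / K

# Toy graph: two disconnected triangles (nodes 0-2 and nodes 3-5).
triangle = np.ones((3, 3)) - np.eye(3)
zeros = np.zeros((3, 3))
A = np.block([[triangle, zeros], [zeros, triangle]])

# Assignment that matches the two triangles exactly.
C_good = np.array([[1, 0], [1, 0], [1, 0], [0, 1], [0, 1], [0, 1]], dtype=float)
# Assignment that mixes nodes from the two triangles.
C_bad = np.array([[1, 0], [1, 0], [0, 1], [1, 0], [0, 1], [0, 1]], dtype=float)

score_good = mincut_objective(A, C_good)
score_bad = mincut_objective(A, C_bad)
print(score_good, score_bad)
```

&lt;p&gt;The assignment that follows the two triangles scores higher than the mixed one, which is the behaviour that the maximization is after.&lt;/p&gt;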

&lt;p&gt;While finding the optimal minCUT is an NP-hard problem, there exist relaxations that can find near-optimal solutions in polynomial time. These relaxations, however, are still very expensive and are not able to generalize to unseen samples.&lt;/p&gt;

&lt;h2 id=&quot;mincut-pooling&quot;&gt;MinCUT pooling&lt;/h2&gt;

&lt;p&gt;&lt;img src=&quot;https://danielegrattarola.github.io/images/2019-07-25/GNN_pooling.png&quot; alt=&quot;Embeddings&quot; /&gt;&lt;/p&gt;

&lt;p&gt;The idea behind minCUT pooling is to take a continuous relaxation of the minCUT problem and implement it as a GNN layer with a custom loss function. By minimizing the custom loss, the GNN learns to find minCUT clusters on any given graph and aggregates the clusters to reduce the graph’s size. &lt;br /&gt;
At the same time, because the layer can be used as a part of a larger architecture, any other loss that is being minimized during training will influence the clusters found by MinCutPool, making them optimal for the particular task at hand.&lt;/p&gt;

&lt;p&gt;At the core of minCUT pooling there is an MLP, which maps the node features \(\mathbf{X}\) to a &lt;strong&gt;continuous&lt;/strong&gt; cluster assignment matrix \(\mathbf{S}\) (of size \(N \times K\)):&lt;/p&gt;

\[\mathbf{S} = \textrm{softmax}(\text{ReLU}(\mathbf{X}\mathbf{W}_1)\mathbf{W}_2)\]

&lt;p&gt;We can then use the MLP to generate \(\mathbf{S}\) on the fly, and reduce the graphs with simple multiplications as:&lt;/p&gt;

\[\mathbf{A}^{pool} = \mathbf{S}^T \mathbf{A} \mathbf{S}; \;\;\; \mathbf{X}^{pool} = \mathbf{S}^T \mathbf{X}.\]
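
&lt;p&gt;The two equations above can be sketched in a few lines of NumPy. This is only an illustration of the shapes involved, with random untrained weights and a random graph; the sizes, the seed, and the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;softmax&lt;/code&gt; helper are assumptions of the example, and a real layer would learn \(\mathbf{W}_1\) and \(\mathbf{W}_2\) by gradient descent:&lt;/p&gt;

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax along the given axis."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
N, F, H, K = 6, 4, 8, 2  # nodes, features, hidden units, clusters (illustrative)

# Random node features and a random symmetric adjacency without self-loops.
X = rng.standard_normal((N, F))
A = np.triu(rng.integers(0, 2, size=(N, N)).astype(float), 1)
A = A + A.T

# Random, untrained MLP weights.
W1 = rng.standard_normal((F, H))
W2 = rng.standard_normal((H, K))

# S = softmax(ReLU(X W1) W2): one row of cluster probabilities per node.
S = softmax(np.maximum(X @ W1, 0.0) @ W2, axis=-1)

# Pooled adjacency (K x K) and pooled node features (K x F).
A_pool = S.T @ A @ S
X_pool = S.T @ X

print(S.shape, A_pool.shape, X_pool.shape)
```

&lt;p&gt;Each row of \(\mathbf{S}\) sums to one, so every node distributes its features and edges across the \(K\) clusters.&lt;/p&gt;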

&lt;p&gt;At this point, we can already make a couple of considerations:&lt;/p&gt;

&lt;ol&gt;
  &lt;li&gt;Nodes with similar features will likely belong to the same cluster because they will be “classified” similarly by the MLP. This is especially true when using message-passing layers before pooling, since they will cause the node features of connected nodes to become similar;&lt;/li&gt;
  &lt;li&gt;Because of the MLP, \(\mathbf{S}\) is pretty fast to compute and the layer can generalize to new graphs once it has been trained.&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;This is already pretty good, and it covers some of the main desiderata of a GNN layer, but we also want to explicitly account for the connectivity of the graph in order to pool it.&lt;/p&gt;

&lt;p&gt;This is where the minCUT optimization comes in.&lt;/p&gt;

&lt;p&gt;By slightly adapting the minCUT formulation above, we can design an auxiliary loss to train the MLP, so that it will learn to solve the minCUT problem in an unsupervised way. &lt;br /&gt;
In practice, our unsupervised regularization loss encourages the MLP to cluster together nodes that are strongly connected with each other and weakly connected with the nodes in the other clusters.&lt;/p&gt;

&lt;p&gt;The full unsupervised loss that we minimize in order to achieve this is:&lt;/p&gt;

\[\mathcal{L}_u = \mathcal{L}_c + \mathcal{L}_o = 
    \underbrace{- \frac{Tr ( \mathbf{S}^T \mathbf{A} \mathbf{S} )}{Tr ( \mathbf{S}^T\mathbf{D} \mathbf{S})}}_{\mathcal{L}_c} + 
    \underbrace{\bigg{\lVert} \frac{\mathbf{S}^T\mathbf{S}}{\|\mathbf{S}^T\mathbf{S}\|_F} - \frac{\mathbf{I}_K}{\sqrt{K}}\bigg{\rVert}_F}_{\mathcal{L}_o},\]

&lt;p&gt;where \(\mathbf{A}\) is the &lt;a href=&quot;https://danielegrattarola.github.io/spektral/utils/convolution/#normalized_adjacency&quot;&gt;normalized&lt;/a&gt; adjacency matrix of the graph.&lt;/p&gt;

&lt;p&gt;Let’s break this loss down and see how it works.&lt;/p&gt;

&lt;h3 id=&quot;cut-loss&quot;&gt;Cut loss&lt;/h3&gt;
&lt;p&gt;The first term, \(\mathcal{L}_c\), encourages the MLP to find cluster assignments that solve the minCUT problem (to see why, compare it with the minCUT maximization that I described above). We refer to this loss as the &lt;strong&gt;cut loss&lt;/strong&gt;.&lt;/p&gt;

&lt;p&gt;In particular, maximizing the numerator (which is what minimizing \(\mathcal{L}_c\) does, given the negative sign) leads to clustering together nodes that are strongly connected on the graph, while the denominator prevents any of the clusters from being too small.&lt;/p&gt;

&lt;p&gt;The cut loss is bounded between -1 and 0, which are &lt;strong&gt;ideally&lt;/strong&gt; reached in the following situations:&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;\(\mathcal{L}_c = 0\) when all pairs of connected nodes are assigned to different clusters;&lt;/li&gt;
  &lt;li&gt;\(\mathcal{L}_c = -1\) when there are \(K\) disconnected components in the graph, and \(\mathbf{S}\) exactly maps the \(K\) components to the \(K\) clusters;&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;The figure below shows what these situations might look like. Note that both cases can only happen if \(\mathbf{S}\) is binary.&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;/images/2019-07-25/loss_bounds.png&quot; alt=&quot;L_c bounds&quot; /&gt;&lt;/p&gt;
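&lt;p&gt;The two bounds can also be checked numerically. The sketch below assumes a toy graph with two disconnected edges and no self-loops, and uses the unnormalized adjacency matrix with its degree matrix; the helper name &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;cut_loss&lt;/code&gt; is introduced here for illustration:&lt;/p&gt;

```python
import numpy as np


def cut_loss(A, S):
    # L_c = -Tr(S^T A S) / Tr(S^T D S), with D the degree matrix of A
    D = np.diag(A.sum(-1))
    return -np.trace(S.T @ A @ S) / np.trace(S.T @ D @ S)


# Two disconnected components, {0, 1} and {2, 3}
A = np.array([[0, 1, 0, 0],
              [1, 0, 0, 0],
              [0, 0, 0, 1],
              [0, 0, 1, 0]])

# One cluster per component: no edge is cut
S_best = np.array([[1, 0], [1, 0], [0, 1], [0, 1]])
# Connected nodes always end up in different clusters: every edge is cut
S_worst = np.array([[1, 0], [0, 1], [1, 0], [0, 1]])

loss_best = cut_loss(A, S_best)    # lower bound of the cut loss
loss_worst = cut_loss(A, S_worst)  # upper bound of the cut loss
```

&lt;p&gt;Here &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;loss_best&lt;/code&gt; reaches the lower bound of \(-1\), while &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;loss_worst&lt;/code&gt; is zero because no edge falls inside a cluster.&lt;/p&gt;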

&lt;p&gt;However, because of the continuous relaxation, \(\mathcal{L}_c\) is non-convex and there are spurious minima that can be found by SGD.&lt;br /&gt;
For example, for \(K = 4\), the uniform assignment matrix&lt;/p&gt;

\[\mathbf{S}_i = (0.25, 0.25, 0.25, 0.25) \;\; \forall i,\]

&lt;p&gt;would cause the numerator and the denominator of \(\mathcal{L}_c\) to be equal, and the loss to be \(-1\).&lt;br /&gt;
A similar situation occurs when all nodes in the graph are assigned to the same cluster.&lt;/p&gt;

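&lt;p&gt;This can be easily verified with NumPy (the snippet computes the ratio without the minus sign, so an output of 1.0 corresponds to \(\mathcal{L}_c = -1\)):&lt;/p&gt;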

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;In&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]:&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# Adjacency matrix
&lt;/span&gt;   &lt;span class=&quot;p&quot;&gt;...:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;A&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt;  
   &lt;span class=&quot;p&quot;&gt;...:&lt;/span&gt;               &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt;  
   &lt;span class=&quot;p&quot;&gt;...:&lt;/span&gt;               &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]])&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;In&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]:&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# Degree matrix
&lt;/span&gt;   &lt;span class=&quot;p&quot;&gt;...:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;D&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;diag&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;A&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;sum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;In&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]:&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# Perfect cluster assignment
&lt;/span&gt;   &lt;span class=&quot;p&quot;&gt;...:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;S&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]])&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;In&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;4&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;trace&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;S&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;T&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;A&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;S&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;trace&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;S&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;T&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;D&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;S&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;Out&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;4&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]:&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;1.0&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;In&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;5&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]:&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# All nodes uniformly distributed 
&lt;/span&gt;   &lt;span class=&quot;p&quot;&gt;...:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;S&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ones&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;In&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;6&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;trace&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;S&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;T&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;A&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;S&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;trace&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;S&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;T&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;D&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;S&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;Out&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;6&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]:&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;1.0&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;In&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;7&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]:&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# All nodes in the same cluster 
&lt;/span&gt;   &lt;span class=&quot;p&quot;&gt;...:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;S&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]])&lt;/span&gt; 

&lt;span class=&quot;n&quot;&gt;In&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;8&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;trace&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;S&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;T&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;A&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;S&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;trace&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;S&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;T&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;D&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;S&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;Out&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;8&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]:&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;1.0&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;h3 id=&quot;orthogonality-loss&quot;&gt;Orthogonality loss&lt;/h3&gt;
&lt;p&gt;The second term, \(\mathcal{L}_o\), helps to avoid such degenerate minima of \(\mathcal{L}_c\) by encouraging the MLP to find clusters that are orthogonal to each other. We call this the &lt;strong&gt;orthogonality loss&lt;/strong&gt;.&lt;/p&gt;

&lt;p&gt;In other words, \(\mathcal{L}_o\) encourages the MLP to “make a decision” about which nodes belong to which clusters, avoiding those degenerate solutions where \(\mathbf{S}\) assigns one \(K\)-th of a node to each cluster.&lt;/p&gt;

&lt;p&gt;Moreover, we can see that the perfect minimizer of \(\mathcal{L}_o\) is only reached if we have \(N \le K\) nodes, because in general, given a \(K\)-dimensional vector space, we cannot find more than \(K\) mutually orthogonal vectors. 
The only way to minimize \(\mathcal{L}_o\) given \(N\) assignment vectors is, therefore, to distribute the nodes between the \(K\) clusters. This causes the MLP to avoid the other type of spurious minima of \(\mathcal{L}_c\), where all nodes are in a single cluster.&lt;/p&gt;
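&lt;p&gt;To make this concrete, here is a small NumPy sketch that evaluates \(\mathcal{L}_o\) on the two degenerate assignments discussed above and on a balanced one. The helper name &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;ortho_loss&lt;/code&gt; and the toy matrices are assumptions for the example:&lt;/p&gt;

```python
import numpy as np


def ortho_loss(S):
    # L_o = || S^T S / ||S^T S||_F  -  I_K / sqrt(K) ||_F
    K = S.shape[-1]
    StS = S.T @ S
    # np.linalg.norm defaults to the Frobenius norm for 2-D inputs
    return np.linalg.norm(StS / np.linalg.norm(StS) - np.eye(K) / np.sqrt(K))


S_uniform = np.ones((4, 2)) / 2                       # half of each node per cluster
S_single = np.array([[1, 0]] * 4, dtype=float)        # all nodes in the same cluster
S_balanced = np.array([[1, 0], [1, 0], [0, 1], [0, 1]], dtype=float)

loss_uniform = ortho_loss(S_uniform)
loss_single = ortho_loss(S_single)
loss_balanced = ortho_loss(S_balanced)
```

&lt;p&gt;Both degenerate assignments are penalized by the same positive amount, \(\sqrt{2 - \sqrt{2}}\) for \(K = 2\), whereas the balanced binary assignment brings the loss to zero.&lt;/p&gt;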

&lt;h2 id=&quot;interaction-of-the-two-losses&quot;&gt;Interaction of the two losses&lt;/h2&gt;

&lt;p&gt;&lt;img src=&quot;/images/2019-07-25/cora_mc_loss+nmi.png&quot; alt=&quot;Loss terms&quot; /&gt;&lt;/p&gt;

&lt;p&gt;We can see how the two loss terms interact with each other to find a good solution to the cluster assignment problem. 
The figure above shows the evolution of the unsupervised loss as the network is trained to cluster the nodes of Cora (plot on the left). We can see that as the network is trained, the normalized mutual information (NMI) between the cluster assignments and the true labels improves, meaning that the layer is learning to find meaningful clusters (plot on the right).&lt;/p&gt;

&lt;p&gt;Note how \(\mathcal{L}_c\) starts from a trivial assignment (-1) due to the random initialization and then moves away from the spurious minima as the orthogonality loss forces the MLP towards more sensible solutions.&lt;/p&gt;

&lt;h3 id=&quot;pooled-graph&quot;&gt;Pooled graph&lt;/h3&gt;
&lt;p&gt;As a further consideration, we can take a closer look at the pooled adjacency matrix \(\mathbf{A}^{pool}\).  &lt;br /&gt;
First of all, we can see that it is a \(K \times K\) matrix that contains the number of links connecting each pair of clusters. For example, the entry \(\mathbf{A}^{pool}_{1,\;2}\) contains the number of links between the nodes in cluster 1 and cluster 2. 
We can also see that the trace of \(\mathbf{A}^{pool}\) is being maximized in \(\mathcal{L}_c\). Therefore, we can expect the diagonal elements \(\mathbf{A}^{pool}_{i,\;i}\) to be much larger than the other entries of \(\mathbf{A}^{pool}\).&lt;/p&gt;

&lt;p&gt;For this reason, \(\mathbf{A}^{pool}\) will represent a graph with very strong self-loops, and the message-passing layers after pooling will have a hard time propagating information on the graph (because the self-loops will keep sending the information of a node back onto itself, and not its neighbors).&lt;/p&gt;

&lt;p&gt;To address this problem, a solution is to remove the diagonal of \(\mathbf{A}^{pool}\) and renormalize the matrix by its degree, before giving it as output of the pooling layer:&lt;/p&gt;

\[\hat{\mathbf{A}} =  \mathbf{A}^{pool} - \mathbf{I}_K \cdot diag(\mathbf{A}^{pool}); \;\; \tilde{\mathbf{A}}^{pool} = \hat{\mathbf{D}}^{-\frac{1}{2}} \hat{\mathbf{A}} \hat{\mathbf{D}}^{-\frac{1}{2}}\]
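&lt;p&gt;A minimal NumPy sketch of this post-processing step is shown below. The function name, the toy matrix, and the small epsilon guarding against isolated clusters are assumptions made for the example:&lt;/p&gt;

```python
import numpy as np


def postprocess(A_pool, eps=1e-12):
    # Remove the self-loops by zeroing out the diagonal
    A_hat = A_pool - np.diag(np.diag(A_pool))
    # Degrees of the graph without self-loops
    d = A_hat.sum(-1)
    # Symmetric normalization D^{-1/2} A_hat D^{-1/2}
    d_inv_sqrt = 1.0 / (np.sqrt(d) + eps)
    return d_inv_sqrt[:, None] * A_hat * d_inv_sqrt[None, :]


# Toy pooled adjacency with a dominant diagonal
A_pool = np.array([[10.0, 1.0, 2.0],
                   [1.0, 8.0, 3.0],
                   [2.0, 3.0, 12.0]])
A_tilde = postprocess(A_pool)
```

&lt;p&gt;The result has a zero diagonal and stays symmetric, so the large within-cluster weights no longer dominate the messages exchanged between clusters.&lt;/p&gt;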

&lt;p&gt;In the paper, we combined minCUT with message-passing layers that have a built-in skip connection, in order to bring each node’s information forward (e.g., Spektral’s &lt;a href=&quot;https://danielegrattarola.github.io/spektral/layers/convolution/#graphconvskip&quot;&gt;GraphConvSkip&lt;/a&gt;). 
However, if your GNN is based on the &lt;a href=&quot;https://danielegrattarola.github.io/spektral/layers/convolution/#graphconv&quot;&gt;graph convolutional networks (GCN)&lt;/a&gt; of &lt;a href=&quot;https://arxiv.org/abs/1609.02907&quot;&gt;Kipf &amp;amp; Welling&lt;/a&gt;, you may want to manually add the self-loops back after pooling.&lt;/p&gt;
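&lt;p&gt;One possible way to add the self-loops back is the renormalization used in GCN, i.e., adding the identity matrix and normalizing again by the new degrees. The sketch below is only an illustration with a hypothetical helper name and a toy \(2 \times 2\) matrix:&lt;/p&gt;

```python
import numpy as np


def add_self_loops_and_normalize(A):
    # Add one self-loop per node, then apply D^{-1/2} (A + I) D^{-1/2}
    A_loops = A + np.eye(A.shape[0])
    d_inv_sqrt = 1.0 / np.sqrt(A_loops.sum(-1))
    return d_inv_sqrt[:, None] * A_loops * d_inv_sqrt[None, :]


# Toy pooled adjacency whose diagonal has already been removed
A_tilde = np.array([[0.0, 0.5],
                    [0.5, 0.0]])
A_gcn = add_self_loops_and_normalize(A_tilde)
```

&lt;p&gt;Whether this is needed depends on the message-passing layer that follows the pooling step, so treat it as an option rather than a required part of the layer.&lt;/p&gt;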

&lt;h3 id=&quot;notes-on-gradient-flow&quot;&gt;Notes on gradient flow&lt;/h3&gt;

&lt;p&gt;&lt;img src=&quot;/images/2019-07-25/mincut_layer.png&quot; alt=&quot;mincut scheme&quot; /&gt;&lt;/p&gt;

&lt;p&gt;The unsupervised loss \(\mathcal{L}_u\) can be optimized on its own, adapting the weights of the MLP to compute an \(\mathbf{S}\) that solves the minCUT problem under the orthogonality constraint.&lt;/p&gt;

&lt;p&gt;However, given the multiplicative interaction between \(\mathbf{S}\) and \(\mathbf{X}\), the gradient of the task-specific loss (i.e., whatever the GNN is being trained to do) can flow through the MLP. We can see in the picture above how there is a path going from the input \(\mathbf{X}^{(t+1)}\) to the output \(\mathbf{X}_{\textrm{pool}}^{(t+1)}\), directly passing through the MLP.&lt;/p&gt;

&lt;p&gt;This means that the overall solution found by the GNN will take into account both the graph structure (to solve minCUT) and the final task.&lt;/p&gt;
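&lt;p&gt;As a sketch of why this happens, consider only the path through the pooled features \(\mathbf{X}^{pool} = \mathbf{S}^T \mathbf{X}\), and write \(\mathcal{L}_{task}\) for the task-specific loss (a symbol introduced here for illustration). The chain rule gives:&lt;/p&gt;

\[\frac{\partial \mathcal{L}_{task}}{\partial \mathbf{S}} = \mathbf{X} \left( \frac{\partial \mathcal{L}_{task}}{\partial \mathbf{X}^{pool}} \right)^T\]

&lt;p&gt;This is an \(N \times K\) matrix that is then backpropagated through the softmax and into the weights of the MLP. An analogous contribution comes from \(\mathbf{A}^{pool} = \mathbf{S}^T \mathbf{A} \mathbf{S}\) whenever the layers after pooling use the pooled adjacency matrix.&lt;/p&gt;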

&lt;h2 id=&quot;code&quot;&gt;Code&lt;/h2&gt;

&lt;p&gt;Implementing minCUT in TensorFlow is fairly straightforward. Let’s start with some setup:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;  &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;tensorflow&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;
  &lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;tensorflow.keras.layers&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Dense&lt;/span&gt;

  &lt;span class=&quot;n&quot;&gt;A&lt;/span&gt;  &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;...&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# Adjacency matrix (N x N)
&lt;/span&gt;  &lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;  &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;...&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# Node features (N x F)
&lt;/span&gt;  &lt;span class=&quot;n&quot;&gt;n_clusters&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;...&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# Number of clusters to find with minCUT
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;First, the layer computes the cluster assignment matrix &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;S&lt;/code&gt; by applying a softmax MLP to the node features:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;H&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Dense&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;16&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;activation&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;relu&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;S&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Dense&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n_clusters&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;activation&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;softmax&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;H&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# Cluster assignment matrix
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;The cut loss is then implemented as:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# Cut loss
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;A_pool&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;matmul&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;transpose&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;matmul&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;A&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;S&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;S&lt;/span&gt;
&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;num&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;trace&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;A_pool&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

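&lt;span class=&quot;c1&quot;&gt;# Degree matrix (N x N); tf.matmul needs a rank-2 tensor, so build the diagonal matrix
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;D&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;linalg&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;diag&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;reduce_sum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;A&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;axis&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;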
&lt;span class=&quot;n&quot;&gt;D_pooled&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;matmul&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;transpose&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;matmul&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;D&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;S&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;S&lt;/span&gt;
&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;den&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;trace&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;D_pooled&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;mincut_loss&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;den&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;And the orthogonality loss is implemented as:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# Orthogonality loss
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;St_S&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;matmul&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;transpose&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;S&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;S&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;I_S&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;eye&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n_clusters&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;ortho_loss&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;norm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;St_S&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;norm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;St_S&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;I_S&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;norm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;I_S&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Finally, the full unsupervised loss of the layer is obtained as the sum of the two auxiliary losses:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;total_loss&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mincut_loss&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;ortho_loss&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;The actual pooling step is a simple multiplication of &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;S&lt;/code&gt; with &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;A&lt;/code&gt; and &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;X&lt;/code&gt;, after which we zero out the diagonal of &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;A_pool&lt;/code&gt; and re-normalize the matrix. Since we already computed &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;A_pool&lt;/code&gt; for the numerator of \(\mathcal{L}_c\), we only need to do:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# Pooling node features
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X_pool&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;matmul&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;transpose&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;S&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# Zeroing out the diagonal
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;A_pool&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;linalg&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;set_diag&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;A_pool&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;zeros&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;A_pool&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)[:&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]))&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# Remove diagonal
&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;# Normalizing A_pool
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;D_pool&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;reduce_sum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;A_pool&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;D_pool&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sqrt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;D_pool&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)[:,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;1e-12&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# Add epsilon to avoid division by 0
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;A_pool&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;A_pool&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;D_pool&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;transpose&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;D_pool&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Wrap this up in a layer, and use the layer in a GNN. Done.&lt;/p&gt;

&lt;p&gt;You can find minCUT pooling implementations both in &lt;a href=&quot;https://danielegrattarola.github.io/spektral/layers/pooling/#mincutpool&quot;&gt;Spektral&lt;/a&gt; and &lt;a href=&quot;https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#module-torch_geometric.nn.dense.mincut_pool&quot;&gt;PyTorch Geometric&lt;/a&gt;.&lt;/p&gt;

&lt;h2 id=&quot;experiments&quot;&gt;Experiments&lt;/h2&gt;

&lt;h3 id=&quot;unsupervised-clustering&quot;&gt;Unsupervised clustering&lt;/h3&gt;
&lt;p&gt;Because the core of MinCutPool is an unsupervised loss that does not require labeled data in order to be minimized, we can optimize \(\mathcal{L}_u\) on its own to test the clustering ability of minCUT.&lt;/p&gt;

&lt;p&gt;A good first test is to check whether the layer is able to cluster a grid (the size of the clusters should be the same) and to isolate communities in a network. 
We see in the figure below that minCUT was able to do this perfectly.&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;/images/2019-07-25/regular_clustering.png&quot; alt=&quot;Clustering with minCUT pooling&quot; /&gt;&lt;/p&gt;

&lt;p&gt;To make things more interesting, we can also test minCUT on the task of graph-based image segmentation. We can build a &lt;a href=&quot;https://scikit-image.org/docs/dev/auto_examples/segmentation/plot_rag.html&quot;&gt;region adjacency graph&lt;/a&gt; from a natural image, and cluster its nodes in order to see if regions with similar colors are clustered together. &lt;br /&gt;
The results look nice, and remember that they were obtained by only optimizing \(\mathcal{L}_u\)!&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;/images/2019-07-25/horses.png&quot; alt=&quot;Horse segmentation with minCUT pooling&quot; /&gt;&lt;/p&gt;

&lt;p&gt;Finally, we also checked the clustering abilities of MinCutPool on the popular citation datasets: Cora, Citeseer, and Pubmed. 
As mentioned before, we used the NMI score to see whether the layer was clustering together nodes of the same class. Note that the layer did not have access to the labels during training.&lt;/p&gt;
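
&lt;p&gt;For reference, NMI can be computed from label counts alone. The sketch below is a plain-Python illustration, not the evaluation code from the paper. It assumes the arithmetic-mean normalization \(2 I(T; P) / (H(T) + H(P))\), which is only one of several conventions in use, and the function names are made up for this example.&lt;/p&gt;

```python
import math
from collections import Counter


def entropy(labels):
    """Shannon entropy (in nats) of a list of discrete labels."""
    n = len(labels)
    return -sum((c / n) * math.log(c / n) for c in Counter(labels).values())


def mutual_information(labels_a, labels_b):
    """Mutual information (in nats) between two labelings of the same nodes."""
    n = len(labels_a)
    count_a = Counter(labels_a)
    count_b = Counter(labels_b)
    joint = Counter(zip(labels_a, labels_b))
    return sum(
        (c / n) * math.log(c * n / (count_a[a] * count_b[b]))
        for (a, b), c in joint.items()
    )


def nmi(labels_true, labels_pred):
    """NMI with arithmetic-mean normalization (an assumption of this sketch)."""
    h_true, h_pred = entropy(labels_true), entropy(labels_pred)
    if h_true + h_pred == 0.0:
        return 1.0  # both labelings put every node in a single group
    return 2.0 * mutual_information(labels_true, labels_pred) / (h_true + h_pred)


# Cluster ids are arbitrary: the same partition with different ids scores 1
classes = [0, 0, 1, 1, 2, 2]
clusters = [2, 2, 0, 0, 1, 1]
print(round(nmi(classes, clusters), 6))
```

&lt;p&gt;Since the score only depends on how the nodes are grouped, it does not matter which id the layer assigns to each cluster, which is what makes it usable without labels at training time.&lt;/p&gt;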

&lt;p&gt;You can check &lt;a href=&quot;https://arxiv.org/abs/1907.00481&quot;&gt;the paper&lt;/a&gt; to see how minCUT fared in comparison to other methods, but in short: it did well, in some cases outperforming the other methods by a full order of magnitude.&lt;/p&gt;

&lt;h3 id=&quot;autoencoder&quot;&gt;Autoencoder&lt;/h3&gt;
&lt;p&gt;Another interesting unsupervised test that we did was to check how much information is preserved in the coarsened graph after pooling.
To do this, we built a simple graph autoencoder with the structure pictured below:&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;/images/2019-07-25/ae.png&quot; alt=&quot;unsupervised reconstruction with AE&quot; /&gt;&lt;/p&gt;

&lt;p&gt;The “Unpool” layer is simply obtained by transposing the same \(\mathbf{S}\) found by minCUT, in order to upscale the graph instead of downscaling it:&lt;/p&gt;

\[\mathbf{A}^\text{unpool} = \mathbf{S} \mathbf{A}^\text{pool} \mathbf{S}^T; \;\; \mathbf{X}^\text{unpool} = \mathbf{S}\mathbf{X}^\text{pool}.\]
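
&lt;p&gt;To make the shapes concrete, here is a minimal sketch of pooling followed by unpooling on a toy graph. It uses plain Python lists and a hand-written hard assignment &lt;code&gt;S&lt;/code&gt; instead of the soft assignment learned by the layer, and it skips the diagonal removal and normalization of the pooled adjacency matrix. The helper names and the toy data are assumptions for illustration only.&lt;/p&gt;

```python
def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m):
    """Transpose a matrix stored as a list of rows."""
    return [list(col) for col in zip(*m)]


# Toy path graph 0 - 1 - 2 - 3, with one scalar feature per node
A = [[0, 1, 0, 0],
     [1, 0, 1, 0],
     [0, 1, 0, 1],
     [0, 0, 1, 0]]
X = [[1.0], [2.0], [3.0], [4.0]]

# Hard assignment of the 4 nodes to 2 clusters: {0, 1} and {2, 3}
S = [[1, 0],
     [1, 0],
     [0, 1],
     [0, 1]]

# Pooling: X_pool = S^T X and A_pool = S^T A S
X_pool = matmul(transpose(S), X)
A_pool = matmul(transpose(S), matmul(A, S))

# Unpooling with the same S: X_unpool = S X_pool and A_unpool = S A_pool S^T
X_unpool = matmul(S, X_pool)
A_unpool = matmul(matmul(S, A_pool), transpose(S))

print(X_pool)    # one row per cluster
print(X_unpool)  # one row per original node
```

&lt;p&gt;Every node simply receives a copy of the features of its cluster, so whatever was lost when the nodes were merged has to be recovered by the layers that follow the unpooling.&lt;/p&gt;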

&lt;p&gt;We tested the graph AE on some very regular graphs that should have been easy to reconstruct after pooling. Surprisingly, this turned out to be a difficult problem for some pooling layers from the GNN literature. MinCUT, on the other hand, was able to hold its own quite nicely.&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;/images/2019-07-25/reconstructions.png&quot; alt=&quot;unsupervised reconstruction with AE&quot; /&gt;&lt;/p&gt;

&lt;h3 id=&quot;supervised-inductive-tasks&quot;&gt;Supervised inductive tasks&lt;/h3&gt;

&lt;p&gt;Finally, we tested whether minCUT provides an improvement on the usual graph classification and graph regression tasks. &lt;br /&gt;
We picked a fixed GNN architecture and tested several pooling strategies by swapping the pooling layers in the network.&lt;/p&gt;

&lt;p&gt;The datasets that we used were:&lt;/p&gt;

&lt;ol&gt;
  &lt;li&gt;&lt;a href=&quot;https://ls11-www.cs.tu-dortmund.de/staff/morris/graphkerneldatasets&quot;&gt;The Benchmark Data Sets for Graph Kernels&lt;/a&gt;;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;https://github.com/FilippoMB/Benchmark_dataset_for_graph_classification&quot;&gt;A synthetic dataset created by F. M. Bianchi to test GNNs&lt;/a&gt;;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;http://quantum-machine.org/datasets/&quot;&gt;The QM9 dataset for the prediction of chemical properties of molecules&lt;/a&gt;.&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;I’m not gonna report the comparisons with other methods, but I will highlight an interesting sanity check that we performed in order to see whether using GNNs and graph pooling even made sense at all.&lt;/p&gt;

&lt;p&gt;Among the various methods that we tested, we also included:&lt;/p&gt;

&lt;ol&gt;
  &lt;li&gt;A simple MLP which did not exploit the relational information carried by the graphs;&lt;/li&gt;
  &lt;li&gt;The same GNN architecture without pooling layers.&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;We were once again surprised to see that, while minCUT yielded a consistent improvement over such simple baselines, other pooling methods did not.&lt;/p&gt;

&lt;h2 id=&quot;conclusions&quot;&gt;Conclusions&lt;/h2&gt;

&lt;p&gt;Working on minCUT pooling was an interesting experience that deepened my understanding of GNNs, and allowed me to see what is really necessary for a GNN to work.&lt;/p&gt;

&lt;p&gt;We have put the paper &lt;a href=&quot;https://arxiv.org/abs/1907.00481&quot;&gt;on arXiv&lt;/a&gt;, and you can check the official implementations of the method in &lt;a href=&quot;https://danielegrattarola.github.io/spektral/layers/pooling/#mincutpool&quot;&gt;Spektral&lt;/a&gt; and &lt;a href=&quot;https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#module-torch_geometric.nn.dense.mincut_pool&quot;&gt;PyTorch Geometric&lt;/a&gt;.&lt;/p&gt;

&lt;p&gt;If you want to use MinCutPool in your own work, you can cite us with:&lt;/p&gt;

&lt;div class=&quot;language-plaintext highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;@inproceedings{bianchi2019mincut,
  title={Spectral Clustering with Graph Neural Networks for Graph Pooling},
  author={Filippo Maria Bianchi and Daniele Grattarola and Cesare Alippi},
  booktitle={Proceedings of the 37th International Conference on Machine Learning (ICML)},
  year={2020}
}
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Cheers!&lt;/p&gt;
</description>
        <pubDate>Thu, 25 Jul 2019 00:00:00 +0000</pubDate>
        
        <link>/posts/2019-07-25/mincut-pooling.html</link>
          
        
            <category>AI</category>
        
            <category>GNN</category>
        
            <category>pooling</category>
        
          
        
            <category>posts</category>
        
          
      </item>
    
    <item>
        <title>Detecting Hostility from Skeletal Graphs Using Non-Euclidean Embeddings</title>
        <description>&lt;p&gt;The first paper on which I worked during my PhD is about &lt;a href=&quot;https://arxiv.org/abs/1805.06299&quot;&gt;detecting changes in sequences of graphs using non-Euclidean geometry and adversarial autoencoders&lt;/a&gt;. As a real-world application of the method presented in the paper, we showed that we could detect epileptic seizures in the brain, by monitoring a stream of functional connectivity brain networks.&lt;/p&gt;

&lt;p&gt;In general, the methodology presented in the paper can work for any data that:&lt;/p&gt;

&lt;ol&gt;
  &lt;li&gt;can be represented as graphs;&lt;/li&gt;
  &lt;li&gt;has a temporal dimension;&lt;/li&gt;
  &lt;li&gt;has a change that you want to identify somewhere along the stream of data;&lt;/li&gt;
  &lt;li&gt;has i.i.d. samples.&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;There are &lt;a href=&quot;https://icon.colorado.edu/#!/networks&quot;&gt;a lot&lt;/a&gt; of temporal networks that can be found in the wild, but not many datasets respect all the requirements at the same time. What’s more, many public datasets have very few samples along the temporal axis.  &lt;!--more--&gt;
Recently, however, I was looking for some nice graph classification dataset on which to test &lt;a href=&quot;https://danielegrattarola.github.io/spektral&quot;&gt;Spektral&lt;/a&gt;, and I stumbled upon the &lt;a href=&quot;http://rose1.ntu.edu.sg/datasets/actionrecognition.asp&quot;&gt;NTU RGB+D&lt;/a&gt; dataset released by the Nanyang Technological University of Singapore.&lt;br /&gt;
The dataset consists of about 60 thousand video clips of people performing everyday actions, including mutual actions and some health-related ones. The reason why I found this dataset is that it contains skeletal annotations for each frame of each video clip, meaning lots and lots of graphs that &lt;a href=&quot;https://arxiv.org/abs/1801.07455&quot;&gt;can be used for graph classification&lt;/a&gt;.&lt;/p&gt;

&lt;h2 id=&quot;ntu-rgbd-for-change-detection&quot;&gt;NTU RGB+D for change detection&lt;/h2&gt;

&lt;p&gt;&lt;img src=&quot;https://danielegrattarola.github.io/images/2019-04-13/graphs.svg&quot; alt=&quot;graphs&quot; title=&quot;Figure 1: examples of hugging and punching graphs.&quot; class=&quot;threeq-width&quot; /&gt;&lt;/p&gt;

&lt;p&gt;While reading through the website, however, I realized that this dataset could actually be a good playground for our change detection methodology as well, because it respects almost all requirements:&lt;/p&gt;

&lt;ol&gt;
  &lt;li&gt;it has graphs;&lt;/li&gt;
  &lt;li&gt;it has a temporal dimension;&lt;/li&gt;
  &lt;li&gt;it has classes, which can be easily converted to what we called the &lt;em&gt;regimes&lt;/em&gt; of our graph streams.&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;The fourth requirement of having i.i.d. samples is due to the nature of the change detection test that we adopted in the paper. The test is able to detect changes in stationarity of a stochastic process, which means that it can tell whether the samples coming from the process have been drawn from a different distribution than the one observed during training. &lt;br /&gt;
In order to do so, the test needs to estimate whether a window of observations from the process is significantly different from what was observed in the nominal regime. This requires having i.i.d. samples in each window.&lt;/p&gt;

&lt;p&gt;By their very nature, however, the graphs in NTU RGB+D are definitely not i.i.d. (they would have been, had the subjects been recorded under a strobe light – dammit!).&lt;br /&gt;
There are several ways of converting a heavily autocorrelated signal to a stationary one, with the simplest one being randomizing along the time axis.
The piece-wise stationarity requirement is a very strong one, and we are looking into relaxing it, but for testing the method on NTU RGB+D we had to stick with it.&lt;/p&gt;
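
&lt;p&gt;A minimal sketch of what randomizing along the time axis can look like is given below. It assumes that the shuffling is done separately inside each regime, so that the position of the change in the stream stays where it was. The function name and the toy stream are hypothetical and are not taken from the code of the paper.&lt;/p&gt;

```python
import random


def shuffle_within_regimes(stream, regimes, seed=0):
    """Shuffle samples along time separately within each regime.

    `stream` is a list of samples (e.g. graphs) and `regimes` is a list of
    regime labels of the same length. Every time step keeps its regime
    label, so the change point does not move, while the temporal order of
    the samples inside each regime is randomized.
    """
    rng = random.Random(seed)
    # Collect the samples of each regime, then shuffle each group
    groups = {}
    for sample, regime in zip(stream, regimes):
        groups.setdefault(regime, []).append(sample)
    for group in groups.values():
        rng.shuffle(group)
    # Refill every time step with a sample taken from its own regime
    return [groups[regime].pop() for regime in regimes]


# Toy stream: three nominal frames followed by two frames of the other regime
frames = ["hug_0", "hug_1", "hug_2", "punch_0", "punch_1"]
labels = ["hug", "hug", "hug", "punch", "punch"]
print(shuffle_within_regimes(frames, labels))
```

&lt;p&gt;Shuffling breaks the ordering between consecutive frames, which is the main source of autocorrelation in a video, but it obviously throws away the dynamics of the action.&lt;/p&gt;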

&lt;h2 id=&quot;setting&quot;&gt;Setting&lt;/h2&gt;

&lt;p&gt;Defining the change detection problem is easy: have a nominal regime of neutral or positive actions like walking, reading, taking a selfie, or being at the computer, and try to detect when the regime changes to a negative action like falling down, getting in fights with people, or feeling sick (there are at least 5 action classes of people acting hurt or sick in NTU RGB+D).&lt;/p&gt;

&lt;p&gt;Applications of this could include:&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;monitoring children and elderly people when they are alone;&lt;/li&gt;
  &lt;li&gt;detecting violence in at-risk, crowded situations;&lt;/li&gt;
  &lt;li&gt;detecting when a driver is distracted;&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;In all of these situations, you might have a pretty good idea of what you &lt;em&gt;want&lt;/em&gt; to be happening at a given time, but have no way of knowing how things could go wrong.&lt;/p&gt;

&lt;p&gt;We chose the “hugging” action for the nominal, all-is-well regime, and we took the “punching/slapping” class to symbolize any unexpected, undesirable behaviour that deviates from our concept of nominal.
Then, we trained our adversarial autoencoder to represent points on an ensemble of constant-curvature manifolds, and we ran the change detection test. 
At this point, it would probably help if one was familiar with the details of &lt;a href=&quot;https://arxiv.org/abs/1805.06299&quot;&gt;the paper&lt;/a&gt;. In short, what we do is:&lt;/p&gt;

&lt;ol&gt;
  &lt;li&gt;take an adversarial graph autoencoder (AAE);&lt;/li&gt;
  &lt;li&gt;train the AAE on the nominal samples that you have at training time;&lt;/li&gt;
  &lt;li&gt;impose a geometric regularization onto the latent space of the AAE, so that the embeddings will lie on a Riemannian constant-curvature manifold (CCM).&lt;br /&gt;
This happens in one of two ways:
    &lt;ol&gt;
      &lt;li&gt;use a prior distribution with support on the CCM to train the AAE;&lt;/li&gt;
      &lt;li&gt;make the encoder maximise the membership of its embeddings to the CCM (this is the one we use for this experiment);&lt;/li&gt;
    &lt;/ol&gt;
  &lt;/li&gt;
  &lt;li&gt;use the trained AAE to represent incoming graphs on the CCM;&lt;/li&gt;
  &lt;li&gt;run the change detection test on the CCM.&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;&lt;img src=&quot;https://danielegrattarola.github.io/images/2019-04-13/embeddings.svg&quot; alt=&quot;embeddings&quot; title=&quot;Figure 2: embeddings produced by the AAE on the three different CCMs. Blue for hugging, orange for punching.&quot; class=&quot;full-width&quot; /&gt;&lt;/p&gt;

&lt;p&gt;This procedure can be adapted to learn a representation on more than one CCM at a time, by having parallel latent spaces for the AAE. This worked pretty well in the paper, so we tried the same here. 
We also chose one of the two types of change detection tests that we introduced in the paper, namely the one we called &lt;em&gt;Riemannian&lt;/em&gt;, because it gave us the best results on the seizure detection problem.&lt;/p&gt;

&lt;h2 id=&quot;results&quot;&gt;Results&lt;/h2&gt;

&lt;p&gt;Running the whole method on the stream of graphs gave us very nice results. We were able to recognize the change from friendly to violent interactions in most experiments, although sometimes the autoencoder failed to capture the differences between the two regimes (and consequently, the CDT couldn’t pick up the change).&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;https://danielegrattarola.github.io/images/2019-04-13/accumulator.svg&quot; alt=&quot;accumulator&quot; title=&quot;Figure 3: accumulators of R-CDT (see the paper) for the three CCMs. The change is marked with the red line, the decision threshold with the green line. &quot; class=&quot;full-width&quot; /&gt;&lt;/p&gt;

&lt;p&gt;An interesting thing that we observed is that when using an ensemble of three different geometries, namely spherical, hyperbolic, and Euclidean, the change would only show up in the spherical CCM. 
This was a consistent result that gave us yet another confirmation of two things:&lt;/p&gt;

&lt;ol&gt;
  &lt;li&gt;assuming Euclidean geometry for the latent space is not always a good idea;&lt;/li&gt;
  &lt;li&gt;our idea of learning a representation on multiple CCMs at the same time worked as expected. Originally, we suggested this trick to potential adopters of our CDT methodology so that they would not have to guess the best geometry for the representation. Now, we have the confirmation that it is indeed a good idea, because the AAE will choose the best geometry for the task on its own.&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;Figure 2 above (hover over the images to see the captions) shows the embeddings produced by the encoder on the test stream of graphs. Figure 3 shows the three &lt;em&gt;accumulators&lt;/em&gt; used in the change detection test to decide whether or not to raise an alarm indicating that a change occurred. 
In both pictures, the decision for raising an alarm is informed almost exclusively by the spherical CCM.&lt;/p&gt;
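
&lt;p&gt;To give an intuition of how an accumulator with a decision threshold raises an alarm, here is a sketch of a generic one-sided CUSUM-style accumulator. This is not the R-CDT from the paper: the per-window score, the &lt;code&gt;drift&lt;/code&gt; and &lt;code&gt;threshold&lt;/code&gt; parameters, and the toy numbers are all assumptions made for illustration.&lt;/p&gt;

```python
def cusum_alarm(scores, drift, threshold):
    """One-sided CUSUM-style test over a sequence of per-window scores.

    The scores are assumed to stay below `drift` in the nominal regime
    and to grow after a change. Returns the accumulator trace and the
    index of the first alarm (or None if no alarm is raised).
    """
    accumulator = 0.0
    trace = []
    alarm = None
    for t, score in enumerate(scores):
        # Accumulate evidence above the drift, never going below zero
        accumulator = max(0.0, accumulator + score - drift)
        trace.append(accumulator)
        if alarm is None and accumulator > threshold:
            alarm = t
    return trace, alarm


# Toy stream: the score jumps from about 0 to about 1 at step 5
scores = [0.1, -0.2, 0.0, 0.2, -0.1, 1.0, 1.1, 0.9, 1.2, 1.0]
trace, alarm = cusum_alarm(scores, drift=0.5, threshold=2.0)
print(alarm)
```

&lt;p&gt;The accumulator stays flat while the stream looks nominal and starts climbing after the change, so the alarm comes with a delay that depends on how high the threshold is set.&lt;/p&gt;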

&lt;h2 id=&quot;conclusions&quot;&gt;Conclusions&lt;/h2&gt;

&lt;p&gt;That’s all, folks!&lt;br /&gt;
This was a pretty little experiment to run, and it gave us further insights into the world of non-Euclidean neural networks. We have actually &lt;a href=&quot;https://arxiv.org/abs/1805.06299&quot;&gt;updated the paper&lt;/a&gt; with the findings of this new experiment, and you can also try and play with our algorithm using the &lt;a href=&quot;https://github.com/danielegrattarola/cdt-ccm-aae&quot;&gt;code on GitHub&lt;/a&gt; (the code there is for the synthetic experiments of the paper, but you can adapt it to any dataset easily).&lt;/p&gt;

&lt;p&gt;If you want to mention our CDT strategy in your work, you can cite:&lt;/p&gt;

&lt;div class=&quot;language-plaintext highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;@article{grattarola2018change,
  title={Change Detection in Graph Streams by Learning Graph Embeddings on Constant-Curvature Manifolds},
  author={Grattarola, Daniele and Zambon, Daniele and Livi, Lorenzo and Alippi, Cesare},
  journal={IEEE Transactions on Neural Networks and Learning Systems},
  year={2019},
  doi={10.1109/TNNLS.2019.2927301}
}
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Cheers!&lt;/p&gt;
</description>
        <pubDate>Sat, 13 Apr 2019 00:00:00 +0000</pubDate>
        
        <link>/posts/2019-04-13/hostility-detection.html</link>
          
        
            <category>AI</category>
        
            <category>experiment</category>
        
            <category>non-euclidean</category>
        
          
        
            <category>posts</category>
        
          
      </item>
    
  </channel>
</rss>
