# Copyright 2023 Quarkslab
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Belief propagation framework
Contains the whole belief propagation implementation.
"""
import logging
import math
import numpy as np
from typing import Any
from collections.abc import Generator
# local imports
from qbindiff.types import Positive, Ratio, RawMapping, SparseMatrix
class BeliefMWM:
    """
    Computes the optimal solution to the **Maximum Weight Matching problem**.
    """

    def __init__(self, sim_matrix: SparseMatrix, epsilon: Positive):
        """
        :param sim_matrix: similarity matrix (sparse numpy matrix)
        :param epsilon: perturbation for algorithm convergence
        """
        self.weights = sim_matrix.copy()  #: The weights sparse matrix
        self._shape = sim_matrix.shape
        self._dtype = sim_matrix.dtype.type
        self._init_messages()
        self.scores: list[float] = []  #: Scores list
        self.max_avg_score: float = 0.0  #: Current maximum average score
        self.best_mapping: RawMapping = ([], [])  #: Current best mapping
        self.best_marginals = None  #: Current associated marginals as a SparseMatrix
        self.epsilon = self._dtype(epsilon)  #: Current epsilon
        # Initial epsilon, used by _update_epsilon to reset the current epsilon
        self._epsilonref = self.epsilon.copy()

    def _init_messages(self) -> None:
        """
        Initializes messages for the belief propagation phase.

        Every message matrix is a copy of the weights, so they all share the same
        sparsity structure (same indptr/indices). The update routines rely on this to
        operate directly on the ``.data`` arrays.
        """
        #: Messages from node to factor targeting the node in the first graph. m(X[ii`] -> f[i])
        self.msg_n2f = self.weights.copy()
        #: Messages from node to factor targeting the node in the second graph. m(X[ii`] -> g[i`])
        self.msg_n2g = self.weights.copy()
        #: Messages from factor to node targeting the node in the first graph. m(f[i] -> X[ii`])
        self.msg_f2n = self.weights.copy()
        #: Messages from factor to node targeting the node in the second graph. m(g[i`] -> X[ii`])
        self.msg_g2n = self.weights.copy()
        #: Messages to the node, also known as max-marginal probability of the node. P(X[ii`])
        self.marginals = self.weights.copy()
        # The matching matrix between the two graphs. It is a mask that has to be applied
        # to self.weights.data
        self.matches_mask = np.zeros_like(self.weights.data, dtype=bool)

    def compute(self, maxiter: int = 1000) -> Generator[int, Any, Any]:
        """
        Repeat the belief propagation round for a given number of iterations.

        Each round updates the messages, rounds them into a matching, updates epsilon
        and then yields the (1-based) iteration number. The generator stops early as
        soon as :meth:`_has_converged` reports convergence.

        :param maxiter: Maximum number of iterations for the algorithm
        :return: generator that yields the iteration number at each iteration
        """
        for niter in range(1, maxiter + 1):
            self._update_messages()
            self._round_messages()
            self._update_epsilon()
            yield niter
            if self._has_converged():
                logging.info("[+] Converged after %d iterations", niter)
                return
        logging.info("[+] Did not converge after %d iterations", maxiter)

    def _update_messages(self) -> None:
        """
        Update the messages, starting from the first graph when it has no more nodes
        than the second one, and from the second graph otherwise.
        """
        if self._shape[0] <= self._shape[1]:
            self._update_messages_primary()
        else:
            self._update_messages_secondary()

    def _update_messages_primary(self) -> None:
        """
        Update messages starting from the first graph
        """
        # Update messages from node to f
        self.msg_n2f.data[:] = self.weights.data
        self._update_factor_g_messages()
        self.msg_n2f.data += self.msg_g2n.data
        self.marginals.data[:] = self.msg_n2f.data
        # Update messages from node to g
        self.msg_n2g.data[:] = self.weights.data
        self._update_factor_f_messages()
        self.msg_n2g.data += self.msg_f2n.data
        self.marginals.data += self.msg_f2n.data

    def _update_messages_secondary(self) -> None:
        """
        Update messages starting from the second graph
        """
        # Update messages from node to g
        self.msg_n2g.data[:] = self.weights.data
        self._update_factor_f_messages()
        self.msg_n2g.data += self.msg_f2n.data
        self.marginals.data[:] = self.msg_n2g.data
        # Update messages from node to f
        self.msg_n2f.data[:] = self.weights.data
        self._update_factor_g_messages()
        self.msg_n2f.data += self.msg_g2n.data
        self.marginals.data += self.msg_g2n.data

    def _update_factor_msg(self, messages: np.ndarray) -> None:
        """
        Update, in place, the messages exchanged by one factor with its nodes.

        ``messages`` holds the node-to-factor messages of every node attached to the
        factor (one row of ``msg_n2f`` or one column of ``msg_n2g``). Let ``max1`` and
        ``max2`` be the largest and second-largest messages, both clipped at 0. The
        entry holding the largest message becomes ``-max2``, and every other entry
        becomes ``-max1 - epsilon``. A factor with at most one node sends 0.

        :param messages: 1-D array of messages, updated in place
        """
        if len(messages) > 1:
            # argpartition puts the two largest values at the end: [second, first]
            arg2, arg1 = np.argpartition(messages, -2)[-2:]
            max2, max1 = np.maximum(0, messages[[arg2, arg1]], dtype=self._dtype)
            messages[:] = -max1 - self.epsilon
            messages[arg1] = -max2
        else:
            messages[:] = self._dtype(0)

    def _update_factor_g_messages(self) -> None:
        """
        Update all the messages from factor g to node (one factor per column).
        """
        # Use the csc (compressed sparse column) format so that each column is a
        # contiguous slice of .data.
        msg_n2g_csc = self.msg_n2g.tocsc()
        self.msg_g2n = self.msg_g2n.tocsc()
        for k in range(self._shape[1]):
            # All the message matrices share the same sparsity structure (same indptr
            # and indices), so the same [begin, end) slice addresses column k in both.
            # This is equivalent to updating self.msg_n2g[:, k] and storing it into
            # self.msg_g2n[:, k], without the cost of sparse slicing.
            begin = msg_n2g_csc.indptr[k]
            end = msg_n2g_csc.indptr[k + 1]
            col = msg_n2g_csc.data[begin:end]
            self._update_factor_msg(col)
            self.msg_g2n.data[begin:end] = col
        # Restore the csr (compressed sparse row) format
        self.msg_g2n = self.msg_g2n.tocsr()

    def _update_factor_f_messages(self) -> None:
        """
        Update all the messages from factor f to node (one factor per row).

        Side effect: each row is a view into ``self.msg_n2f.data``, so ``msg_n2f``
        itself is overwritten with the factor messages. Callers reset ``msg_n2f``
        from the weights before it is read again.
        """
        for k in range(self._shape[0]):
            # Shared sparsity structure: the [begin, end) slice of .data is row k in
            # every message matrix. This is equivalent to updating self.msg_n2f[k] and
            # storing it into self.msg_f2n[k], without the cost of sparse slicing.
            begin = self.msg_n2f.indptr[k]
            end = self.msg_n2f.indptr[k + 1]
            row = self.msg_n2f.data[begin:end]
            self._update_factor_msg(row)
            self.msg_f2n.data[begin:end] = row

    def _round_messages(self) -> None:
        """
        Rounding phase: select every pair with a positive marginal and record the score.
        """
        self.matches_mask[:] = self.marginals.data > 0
        self.scores.append(self.current_score)

    def _update_epsilon(self) -> None:
        """
        Epsilon phase.

        When the average score (score per selected pair) improves, the current mapping
        and marginals are saved as the best ones. Once at least 10 scores have been
        recorded, epsilon is reset to its initial value on improvement and otherwise
        multiplied by 1.2.
        """
        avg_score = self.scores[-1] / max(self.matches_mask.sum(), 1)
        if self.max_avg_score < avg_score:
            self.best_mapping = self.current_mapping
            self.best_marginals = self.marginals.copy()
            self.max_avg_score = avg_score
            if len(self.scores) >= 10:
                # Rebinding is safe: `*=` below creates a new scalar, so
                # _epsilonref is never mutated.
                self.epsilon = self._epsilonref
        elif len(self.scores) >= 10:
            self.epsilon *= 1.2

    def _has_converged(self, window: int = 60, pattern_size: int = 15) -> bool:
        """
        Decide whether the algorithm has converged.

        The algorithm has converged if the last `pattern_size` scores appear again, as a
        contiguous run, further back within the last `window` scores. At least
        ``2 * pattern_size`` scores are needed.

        :param window: Number of the latest scores to consider when searching for the pattern
        :param pattern_size: Size of the pattern
        :return: True if the algorithm has converged, False otherwise
        :rtype: bool
        """
        # Last `window` scores, most recent first
        scores = self.scores[: -window - 1 : -1]
        if len(scores) < 2 * pattern_size:
            return False
        pattern = scores[:pattern_size]
        # Only full-length slices can match the pattern
        for i in range(pattern_size, len(scores) - pattern_size + 1):
            if pattern == scores[i : i + pattern_size]:
                return True
        return False

    @property
    def current_mapping(self) -> RawMapping:
        """
        Current mapping as a pair ``(rows, cols)`` of index arrays.

        Among the selected pairs, only those that are the first occurrence of their row
        AND the first occurrence of their column are kept. As a result, every row and
        every column appears at most once.
        """
        # Recover the row of each selected stored entry from the CSR indptr
        rows = (
            np.searchsorted(self.weights.indptr, self.matches_mask.nonzero()[0], side="right") - 1
        )
        cols = self.weights.indices[self.matches_mask]
        mask = np.intersect1d(
            np.unique(rows, return_index=True)[1], np.unique(cols, return_index=True)[1]
        )
        return rows[mask], cols[mask]

    @property
    def current_score(self) -> float:
        """Current score: sum of the weights of the selected pairs"""
        return self.weights.data[self.matches_mask].sum()

    @property
    def current_marginals(self) -> SparseMatrix:
        """
        Current marginals in a sparse matrix, mapped to ``e^x / (1 + e^x)``.
        """
        curr_marginals = self.marginals.copy()
        # e^x may overflow to +inf, which is expected, hence the silenced warning.
        # Values are clipped to [0, 1e6] since 1e6/(1e6+1) ~ 0.999999: as these are
        # probabilities, all the values > 99.9999% are treated as the same.
        with np.errstate(over="ignore"):
            odds = np.clip(np.power(math.e, curr_marginals.data), 0, 1e6)
        curr_marginals.data[:] = odds / (1 + odds)
        return curr_marginals
class BeliefQAP(BeliefMWM):
    """
    Computes an approximate solution to the **Quadratic Assignment problem**.
    """

    def __init__(
        self,
        sim_matrix: SparseMatrix,
        squares: SparseMatrix,
        tradeoff: Ratio,
        epsilon: Positive,
    ):
        """
        :param sim_matrix: similarity matrix (sparse numpy matrix)
        :param squares: square matrix with one row/column per stored entry (edge) of
            ``sim_matrix``. Its row sums are added element-wise to the similarity
            weights, so the sizes must agree.
        :param tradeoff: trade-off value between node similarity and squares (callgraph).
            The similarity weights are scaled by ``2 * tradeoff / (1 - tradeoff)``, so
            values close to 1 favor the similarity and values close to 0 favor the
            squares. With ``tradeoff == 1`` the squares are dropped entirely.
        :param epsilon: perturbation value for convergence
        """
        super().__init__(sim_matrix, epsilon)
        if tradeoff == 1:
            logging.warning("[+] meaningless tradeoff for NAQP")
            # Drop every square. A new matrix is built (instead of an in-place
            # subtraction) so that the caller's `squares` object is never modified.
            squares = squares - squares
        else:
            self.weights.data *= 2 * tradeoff / (1 - tradeoff)
        self._init_squares(squares)

    @property
    def current_score(self) -> float:
        """Current score of the solution: similarity score plus twice `numsquares`"""
        score = super().current_score
        score += self.numsquares * 2
        return score

    @property
    def numsquares(self) -> float:
        """
        Number of squares among the selected pairs.

        The matched sub-matrix is summed, with its diagonal counted twice, and halved,
        so each symmetric off-diagonal pair is counted once.

        NOTE(review): this reads the current square messages ``msg_h2n`` (which equal
        the square weights only before the first square-message update), not
        ``weights_squares`` — confirm this is intended.
        """
        squares = self.msg_h2n[self.matches_mask][:, self.matches_mask]
        return (squares.sum() + squares.diagonal().sum()) / 2

    def _init_squares(self, squares: SparseMatrix) -> None:
        """
        Initializes the square matrix
        :param squares: square matrix
        """
        #: Messages from square factor to node. m(h[ii`jj`] -> X[ii`])
        self.msg_h2n = squares.astype(self._dtype)
        #: The additional weight matrix addressing the squares weights. W[ii`jj`]
        self.weights_squares = self.msg_h2n.data.copy()
        #: Number of squares (ii`, jj`) for each edge ii` (stored entries per CSR row)
        self.squares_per_edge = np.diff(squares.indptr)

    def _weights_with_square_messages(self) -> np.ndarray:
        """Return a new array: weights.data plus the row sums of msg_h2n (one per edge)."""
        partial = self.weights.data.copy()
        # np.asarray(...).ravel() works whether sum() returns an np.matrix (sparse
        # matrices) or a 1-D ndarray (sparse arrays).
        partial += np.asarray(self.msg_h2n.sum(axis=1)).ravel()
        return partial

    def _update_messages_primary(self) -> None:
        """
        Update messages starting from the first graph
        """
        partial = self._weights_with_square_messages()
        # Update messages from node to f
        self.msg_n2f.data[:] = partial
        self._update_factor_g_messages()
        self.msg_n2f.data += self.msg_g2n.data
        self.marginals.data[:] = self.msg_n2f.data
        # Update messages from node to g
        self.msg_n2g.data[:] = partial
        self._update_factor_f_messages()
        self.msg_n2g.data += self.msg_f2n.data
        self.marginals.data += self.msg_f2n.data

    def _update_messages_secondary(self) -> None:
        """
        Update messages starting from the second graph
        """
        partial = self._weights_with_square_messages()
        # Update messages from node to g
        self.msg_n2g.data[:] = partial
        self._update_factor_f_messages()
        self.msg_n2g.data += self.msg_f2n.data
        self.marginals.data[:] = self.msg_n2g.data
        # Update messages from node to f
        self.msg_n2f.data[:] = partial
        self._update_factor_g_messages()
        self.msg_n2f.data += self.msg_g2n.data
        self.marginals.data += self.msg_g2n.data

    def _round_messages(self) -> None:
        """
        Rounding phase, followed by the update of the square factor messages
        """
        super()._round_messages()
        self._update_square_factor_messages()

    def _update_square_factor_messages(self) -> None:
        r"""
        Update the messages denoted by
        $$
        m_{h_{ii\prime jj\prime} \rightarrow{} X_{ii\prime}}
        $$
        The formula is the following one :
        $$
        m_{h_{ii\prime jj\prime} \rightarrow{} X_{ii\prime}} = \text{clip}(w_{ii\prime jj\prime} + m_{X_{jj\prime} \rightarrow{} h_{ii\prime jj\prime}}) - \text{clip}(m_{X_{jj\prime} \rightarrow{} h_{ii\prime jj\prime}})
        $$
        where $$ \text{clip}(x) = max(0, x) $$

        With ``M = msg_h2n``, ``b = marginals.data`` and ``W = weights_squares``, the
        code computes ``P[e, f] = max(0, M[e, f] - b[e])`` and then
        ``M'[e, f] = max(0, W[e, f] - P[f, e])``. Writing
        ``m(X[f] -> h[ef]) = b[f] - M[f, e]``, this is
        ``max(0, W - max(0, -m))``, which equals the formula above when ``W >= 0``.
        """
        # P[e, f] = max(0, M[e, f] - b[e]); msg_h2n is reused as storage (it is
        # entirely rewritten below). Each CSR row e holds squares_per_edge[e] entries.
        partial = self.msg_h2n
        partial.data -= np.repeat(self.marginals.data, self.squares_per_edge)
        np.maximum(partial.data, 0, out=partial.data)
        # Transpose so that position (e, f) holds P[f, e].
        # NOTE(review): aligning partial.data with msg_h2n.data assumes the squares
        # matrix is structurally symmetric with sorted indices — confirm.
        partial = partial.T.tocsr()
        self.msg_h2n.data[:] = self.weights_squares - partial.data
        np.maximum(self.msg_h2n.data, 0, out=self.msg_h2n.data)