Shortcuts

Source code for rl4co.models.zoo.symnco.losses

from einops import rearrange
from torch.nn.functional import cosine_similarity


def problem_symmetricity_loss(reward, log_likelihood, dim=1):
    """REINFORCE loss for problem symmetricity (``L_ps`` in the SymNCO paper).

    Each instance's baseline is the mean reward over all of its augmented
    versions, taken along ``dim``.

    Args:
        reward: Reward tensor whose axis ``dim`` indexes the augmentations.
        log_likelihood: Log-likelihood tensor, multiplied elementwise with the
            advantage (presumably the same shape as ``reward`` -- confirm at
            call sites).
        dim: Axis of ``reward`` that holds the augmentations.

    Returns:
        The mean of ``-(reward - baseline) * log_likelihood`` as a tensor, or
        the integer ``0`` when fewer than two augmentations are present
        (a mean-over-augmentations baseline is then meaningless).
    """
    if reward.shape[dim] < 2:
        return 0
    baseline = reward.mean(dim=dim, keepdim=True)
    return (-(reward - baseline) * log_likelihood).mean()
def solution_symmetricity_loss(reward, log_likelihood, dim=-1):
    """REINFORCE loss for solution symmetricity (``L_ss`` in the SymNCO paper).

    The shared baseline is the average reward over all start nodes, computed
    along ``dim``.

    Args:
        reward: Reward tensor whose axis ``dim`` indexes the start nodes.
        log_likelihood: Log-likelihood tensor, multiplied elementwise with the
            advantage (presumably the same shape as ``reward`` -- confirm).
        dim: Axis of ``reward`` that holds the start nodes (last axis by
            default).

    Returns:
        The mean policy-gradient term as a tensor, or the integer ``0`` when
        there are fewer than two start nodes.
    """
    # A shared baseline needs at least two start nodes to compare against.
    if reward.shape[dim] >= 2:
        shared_baseline = reward.mean(dim=dim, keepdim=True)
        policy_term = -(reward - shared_baseline) * log_likelihood
        return policy_term.mean()
    return 0
def invariance_loss(proj_embed, num_augment):
    """Loss for invariant representation on projected nodes.

    Corresponds to ``L_inv`` in the SymNCO paper, defined there as the
    *negative* cosine similarity between the projected embeddings of an
    instance and those of its augmentations. Minimizing this loss therefore
    pulls the augmented projections towards the reference projection.

    Args:
        proj_embed: Projected embeddings whose leading dimension is laid out
            as ``(batch * num_augment)`` with the augmentation index varying
            fastest (the einops ``"(b a) ..."`` layout). Augmentation 0 of
            each group is the reference. The last axis is the feature axis
            over which cosine similarity is computed.
        num_augment: Number of augmented copies per instance, including the
            reference.

    Returns:
        Minus the mean (over batch and any remaining non-feature axes) of the
        cosine similarity to the reference, summed over augmentations
        ``1 .. num_augment - 1``. Returns the integer ``0`` when
        ``num_augment < 2`` (nothing to compare), matching the other SymNCO
        losses.
    """
    if num_augment < 2:
        # Previously sum([]) produced int 0 and `.mean()` raised AttributeError.
        return 0
    # Equivalent to rearrange(proj_embed, "(b a) ... -> b a ...", a=num_augment);
    # reshape raises if the leading dim is not divisible by num_augment.
    batch = proj_embed.shape[0] // num_augment
    pe = proj_embed.reshape(batch, num_augment, *proj_embed.shape[1:])
    similarity = sum(
        cosine_similarity(pe[:, 0], pe[:, i], dim=-1) for i in range(1, num_augment)
    )
    # Negate: maximizing similarity == minimizing this loss.
    return -similarity.mean()

© Copyright Federico Berto, Chuanbo Hua, Junyoung Park. Revision f4bc96ca.

Built with Sphinx using a theme provided by Read the Docs.