From 3143df35b75317d5ddd266160a1c2ac21a4e9183 Mon Sep 17 00:00:00 2001
From: Ryan Mueller
Date: Mon, 3 Nov 2025 12:25:57 -0500
Subject: [PATCH 1/2] Added alpha beta and gamma values

---
 pytorch_msssim/ssim.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/pytorch_msssim/ssim.py b/pytorch_msssim/ssim.py
index 16380e2..a111cfb 100644
--- a/pytorch_msssim/ssim.py
+++ b/pytorch_msssim/ssim.py
@@ -60,7 +60,8 @@ def _ssim(
     data_range: float,
     win: Tensor,
     size_average: bool = True,
-    K: Union[Tuple[float, float], List[float]] = (0.01, 0.03)
+    K: Union[Tuple[float, float], List[float]] = (0.01, 0.03),
+    alpha_beta_gamma: Tuple[float, float, float] = (1, 1, 1)
 ) -> Tuple[Tensor, Tensor]:
     r""" Calculate ssim index for X and Y
 
@@ -80,6 +81,7 @@ def _ssim(
 
     C1 = (K1 * data_range) ** 2
     C2 = (K2 * data_range) ** 2
+    C3 = C2 / 2
 
     win = win.to(X.device, dtype=X.dtype)
 
@@ -93,9 +95,16 @@ def _ssim(
     sigma1_sq = compensation * (gaussian_filter(X * X, win) - mu1_sq)
     sigma2_sq = compensation * (gaussian_filter(Y * Y, win) - mu2_sq)
     sigma12 = compensation * (gaussian_filter(X * Y, win) - mu1_mu2)
+    sigma1 = sigma1_sq ** 0.5
+    sigma2 = sigma2_sq ** 0.5
+
+    alpha, beta, gamma = alpha_beta_gamma
+    luminance = (2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)
+    contrast = (2 * sigma1 * sigma2 + C2) / (sigma1_sq + sigma2_sq + C2)
+    structure = (sigma12 + C3) / (sigma1 + sigma2 + C3)
 
-    cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)  # set alpha=beta=gamma=1
-    ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map
+    cs_map = (contrast ** beta) * (structure ** gamma)
+    ssim_map = (luminance ** alpha) * cs_map
 
     ssim_per_channel = torch.flatten(ssim_map, 2).mean(-1)
     cs = torch.flatten(cs_map, 2).mean(-1)

From 97c076aa93db147f7e92bd8f78e311cc7d3e3f2e Mon Sep 17 00:00:00 2001
From: Ryan Mueller
Date: Mon, 3 Nov 2025 12:54:20 -0500
Subject: [PATCH 2/2] Working SSIM with adjustable alpha, beta, gamma

---
 pytorch_msssim/ssim.py | 26 ++++++++++++++++++++------
 1 file changed, 20 insertions(+), 6 deletions(-)

diff --git a/pytorch_msssim/ssim.py b/pytorch_msssim/ssim.py
index a111cfb..92272a1 100644
--- a/pytorch_msssim/ssim.py
+++ b/pytorch_msssim/ssim.py
@@ -61,7 +61,7 @@ def _ssim(
     win: Tensor,
     size_average: bool = True,
     K: Union[Tuple[float, float], List[float]] = (0.01, 0.03),
-    alpha_beta_gamma: Tuple[float, float, float] = (1, 1, 1)
+    alpha_beta_gamma: Union[Tuple[float, float, float], List[float]] = (1., 1., 1.),
 ) -> Tuple[Tensor, Tensor]:
     r""" Calculate ssim index for X and Y
 
@@ -98,11 +98,11 @@ def _ssim(
     sigma1 = sigma1_sq ** 0.5
     sigma2 = sigma2_sq ** 0.5
 
-    alpha, beta, gamma = alpha_beta_gamma
     luminance = (2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)
     contrast = (2 * sigma1 * sigma2 + C2) / (sigma1_sq + sigma2_sq + C2)
-    structure = (sigma12 + C3) / (sigma1 + sigma2 + C3)
+    structure = (sigma12 + C3) / (sigma1 * sigma2 + C3)
 
+    alpha, beta, gamma = alpha_beta_gamma
     cs_map = (contrast ** beta) * (structure ** gamma)
     ssim_map = (luminance ** alpha) * cs_map
 
@@ -120,6 +120,7 @@ def ssim(
     win_sigma: float = 1.5,
     win: Optional[Tensor] = None,
     K: Union[Tuple[float, float], List[float]] = (0.01, 0.03),
+    alpha_beta_gamma: Union[Tuple[float, float, float], List[float]] = (1., 1., 1.),
     nonnegative_ssim: bool = False,
 ) -> Tensor:
     r""" interface of ssim
@@ -132,6 +133,7 @@ def ssim(
         win_sigma: (float, optional): sigma of normal distribution
         win (torch.Tensor, optional): 1-D gauss kernel. if None, a new kernel will be created according to win_size and win_sigma
         K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.
+        alpha_beta_gamma (list or tuple, optional): scalar constants (alpha, beta, gamma). Controls relative strength of luminance, contrast, and structure terms.
         nonnegative_ssim (bool, optional): force the ssim response to be nonnegative with relu
     Returns:
         torch.Tensor: ssim results
@@ -160,7 +162,7 @@ def ssim(
         win = _fspecial_gauss_1d(win_size, win_sigma)
         win = win.repeat([X.shape[1]] + [1] * (len(X.shape) - 1))
 
-    ssim_per_channel, cs = _ssim(X, Y, data_range=data_range, win=win, size_average=False, K=K)
+    ssim_per_channel, cs = _ssim(X, Y, data_range=data_range, win=win, size_average=False, K=K, alpha_beta_gamma=alpha_beta_gamma)
     if nonnegative_ssim:
         ssim_per_channel = torch.relu(ssim_per_channel)
 
@@ -179,7 +181,9 @@ def ms_ssim(
     win_sigma: float = 1.5,
     win: Optional[Tensor] = None,
     weights: Optional[List[float]] = None,
-    K: Union[Tuple[float, float], List[float]] = (0.01, 0.03)
+    K: Union[Tuple[float, float], List[float]] = (0.01, 0.03),
+    alpha_beta_gamma: Union[Tuple[float, float, float], List[float]] = (1., 1., 1.),
+
 ) -> Tensor:
     r""" interface of ms-ssim
     Args:
@@ -192,6 +196,8 @@ def ms_ssim(
         win (torch.Tensor, optional): 1-D gauss kernel. if None, a new kernel will be created according to win_size and win_sigma
         weights (list, optional): weights for different levels
         K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.
+        alpha_beta_gamma (list or tuple, optional): scalar constants (alpha, beta, gamma). Controls relative strength of luminance, contrast, and structure terms.
+
     Returns:
         torch.Tensor: ms-ssim results
     """
@@ -234,7 +240,7 @@ def ms_ssim(
     levels = weights_tensor.shape[0]
     mcs = []
     for i in range(levels):
-        ssim_per_channel, cs = _ssim(X, Y, win=win, data_range=data_range, size_average=False, K=K)
+        ssim_per_channel, cs = _ssim(X, Y, win=win, data_range=data_range, size_average=False, K=K, alpha_beta_gamma=alpha_beta_gamma)
 
         if i < levels - 1:
             mcs.append(torch.relu(cs))
@@ -262,6 +268,7 @@ def __init__(
         channel: int = 3,
         spatial_dims: int = 2,
         K: Union[Tuple[float, float], List[float]] = (0.01, 0.03),
+        alpha_beta_gamma: Union[Tuple[float, float, float], List[float]] = (1., 1., 1.),
         nonnegative_ssim: bool = False,
     ) -> None:
         r""" class for ssim
@@ -272,6 +279,7 @@ def __init__(
             win_sigma: (float, optional): sigma of normal distribution
             channel (int, optional): input channels (default: 3)
            K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.
+            alpha_beta_gamma (list or tuple, optional): scalar constants (alpha, beta, gamma). Controls relative strength of luminance, contrast, and structure terms.
             nonnegative_ssim (bool, optional): force the ssim response to be nonnegative with relu.
         """
 
@@ -281,6 +289,7 @@ def __init__(
         self.size_average = size_average
         self.data_range = data_range
         self.K = K
+        self.alpha_beta_gamma = alpha_beta_gamma
         self.nonnegative_ssim = nonnegative_ssim
 
     def forward(self, X: Tensor, Y: Tensor) -> Tensor:
@@ -291,6 +300,7 @@ def forward(self, X: Tensor, Y: Tensor) -> Tensor:
             size_average=self.size_average,
             win=self.win,
             K=self.K,
+            alpha_beta_gamma=self.alpha_beta_gamma,
             nonnegative_ssim=self.nonnegative_ssim,
         )
 
@@ -306,6 +316,7 @@ def __init__(
         spatial_dims: int = 2,
         weights: Optional[List[float]] = None,
         K: Union[Tuple[float, float], List[float]] = (0.01, 0.03),
+        alpha_beta_gamma: Union[Tuple[float, float, float], List[float]] = (1., 1., 1.),
     ) -> None:
         r""" class for ms-ssim
         Args:
@@ -316,6 +327,7 @@ def __init__(
            channel (int, optional): input channels (default: 3)
            weights (list, optional): weights for different levels
            K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.
+            alpha_beta_gamma (list or tuple, optional): scalar constants (alpha, beta, gamma). Controls relative strength of luminance, contrast, and structure terms.
         """
 
         super(MS_SSIM, self).__init__()
@@ -325,6 +337,7 @@ def __init__(
         self.data_range = data_range
         self.weights = weights
         self.K = K
+        self.alpha_beta_gamma = alpha_beta_gamma
 
     def forward(self, X: Tensor, Y: Tensor) -> Tensor:
         return ms_ssim(
@@ -335,4 +348,5 @@ def forward(self, X: Tensor, Y: Tensor) -> Tensor:
             win=self.win,
             weights=self.weights,
             K=self.K,
+            alpha_beta_gamma=self.alpha_beta_gamma,
         )