diff --git a/.github/workflows/prerelease.yml b/.github/workflows/prerelease.yml
index 5bb3aa656..0adac6b8c 100644
--- a/.github/workflows/prerelease.yml
+++ b/.github/workflows/prerelease.yml
@@ -4,7 +4,7 @@
 on:
   workflow_dispatch:
 
 env:
-  APP_VERSION: 2.3.7
+  APP_VERSION: 2.3.8
 
 jobs:
diff --git a/CHANGELOG.md b/CHANGELOG.md
index ae2b07811..aaf3bc30e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,15 @@
+# RVE 2.3.8 pre-release
+### Added
+ - Open output folder button
+ - Cap output scale based on model scale
+### Changed
+ - UI tweaks
+### Fixed
+ - ROCm showing CUDA on the front end
+ - GMFSS not working
+### Removed
+ - GIMM interpolation model
+ - Locking the app to a single instance
 # RVE 2.3.7
 ### Added:
 - PyTorch 2.9
diff --git a/README.md b/README.md
index 73598bc36..b1ce8fa82 100644
--- a/README.md
+++ b/README.md
@@ -3,7 +3,7 @@
 [![pypresence](https://img.shields.io/badge/using-pypresence-00bb88.svg?style=for-the-badge&logo=discord&logoWidth=20)](https://github.com/qwertyquerty/pypresence)
 ![license](https://img.shields.io/github/license/tntwise/real-video-enhancer)
-![Version](https://img.shields.io/badge/Version-2.3.6-blue)
+![Version](https://img.shields.io/badge/Version-2.3.8-blue)
 ![downloads_total](https://img.shields.io/github/downloads/tntwise/REAL-Video-Enhancer/total.svg?label=downloads%40total)
 Discord Shield
@@ -40,7 +40,7 @@
 # Introduction
-REAL Video Enhancer is a redesigned and enhanced version of the original Rife ESRGAN App for Linux. This program offers convenient access to frame interpolation and upscaling functionalities on Windows, Linux and MacOS , and is an alternative to outdated software like Flowframes or enhancr.
+REAL Video Enhancer is a redesigned and enhanced version of the original Rife ESRGAN App for Linux. This program offers convenient access to frame interpolation and upscaling functionalities on Windows, Linux and MacOS, and is an alternative to outdated software like Flowframes or enhancr.
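The changelog's "Open output folder button" entry shows up later in this diff only as `self.openOutputFolderButton.setEnabled(True)` inside `loadVideo`; the click handler itself is not part of the patch. Below is a minimal sketch of how such a button is typically wired in a Qt application, assuming a PySide6-style binding and a hypothetical `open_output_folder` helper (not the project's actual implementation):

```python
import os

from PySide6.QtCore import QUrl            # assumption: PySide6 bindings
from PySide6.QtGui import QDesktopServices


def open_output_folder(output_file_path: str) -> None:
    """Open the folder containing the rendered output in the system file manager."""
    folder = os.path.dirname(os.path.abspath(output_file_path))
    # QDesktopServices delegates to the platform handler (Explorer, Finder, xdg-open).
    QDesktopServices.openUrl(QUrl.fromLocalFile(folder))


# Hypothetical wiring inside the main window, once the UI is set up:
# self.openOutputFolderButton.clicked.connect(
#     lambda: open_output_folder(self.outputFileText.text())
# )
```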

@@ -128,7 +128,7 @@ git clone --recurse-submodules https://github.com/TNTwise/REAL-Video-Enhancer # Stable -git clone --recurse-submodules https://github.com/TNTwise/REAL-Video-Enhancer --branch 2.3.6 +git clone --recurse-submodules https://github.com/TNTwise/REAL-Video-Enhancer --branch 2.3.8 ``` # Building: @@ -161,6 +161,7 @@ python3 build.py --build BUILD_OPTION --copy_backend | Software Used | For | Link| |--|--|--| | FFmpeg | Multimedia framework for handling video, audio, and other media files | https://ffmpeg.org/ +| QT | GUI framework | https://qt.io/ | FFMpeg Builds | Pre-compiled builds of FFMpeg. | Windows/Linux: https://github.com/BtbN/FFmpeg-Builds, MacOS: https://github.com/eko5624/mpv-mac | PyTorch | Neural Network Inference (CUDA/ROCm/TensorRT) | https://pytorch.org/ | NCNN | Neural Network Inference (Vulkan) | https://github.com/tencent/ncnn diff --git a/REAL-Video-Enhancer.nsi b/REAL-Video-Enhancer.nsi index e3e3d4f94..cb17fe242 100644 --- a/REAL-Video-Enhancer.nsi +++ b/REAL-Video-Enhancer.nsi @@ -8,12 +8,12 @@ ; Custom defines !define NAME "REAL Video Enhancer" !define APPFILE "REAL-Video-Enhancer.exe" - !define VERSION "2.3.7" + !define VERSION "2.3.8" !define SLUG "${NAME} v${VERSION}" !define COMPANYNAME "TNTwise" !define VERSIONMAJOR 2 !define VERSIONMINOR 3 - !define VERSIONBUILD 7 + !define VERSIONBUILD 8 !define DISPLAYVERSION "${VERSIONMAJOR}.${VERSIONMINOR}.${VERSIONBUILD}" !define INSTALLSIZE 297000 diff --git a/REAL-Video-Enhancer.py b/REAL-Video-Enhancer.py index 83d34b081..e98961fb5 100644 --- a/REAL-Video-Enhancer.py +++ b/REAL-Video-Enhancer.py @@ -340,25 +340,27 @@ def switchToDownloadPage(self): self.animationHandler.fadeInAnimation(self.stackedWidget) def updateVideoGUIText(self): + self.settings.readSettings() if self.isVideoLoaded: upscaleModelName = self.upscaleModelComboBox.currentText() interpolateModelName = self.interpolateModelComboBox.currentText() interpolateTimes = self.getInterpolationMultiplier(interpolateModelName) scale = self.getUpscaleModelScale(upscaleModelName) - text = ( + new_bitrate = 8 if "10" not in self.settingsTab.in_pix_fmt else 10 + inputText = ( f"FPS: {round(self.videoFps, 0)} -> {round(self.videoFps * interpolateTimes, 0)}\n" + f"Resolution: {self.videoWidth}x{self.videoHeight} -> {self.videoWidth * scale}x{self.videoHeight * scale}\n" + f"Frame Count: {self.videoFrameCount} -> {int(round(self.videoFrameCount * interpolateTimes, 0))}\n" - + f"Bitrate: {self.videoBitrate}\n" - + f"Encoder: {self.videoEncoder}\n" - + f"Container: {self.videoContainer}\n" - + f"Color Space: {self.colorSpace}\n" - + f"Pixel Format: {self.pixelFMT}\n" - + f"HDR: {self.videoHDR}\n" - + f"Bit Depth: {self.videoBitDepth} bit\n" + + f"Encoder: {self.videoEncoder} -> {self.settings.settings['encoder']}\n" + + f"Container: {self.videoContainer} -> {self.settings.settings['video_container']}\n" + + f"Color Space: {self.colorSpace} -> {self.colorSpace}\n" + + f"Pixel Format: {self.settingsTab.in_pix_fmt} -> {self.settingsTab.out_pixel_fmt}\n" + + f"HDR: {self.videoHDR} -> {self.videoHDR if self.settings.settings['auto_hdr_mode'] == 'True' else 'False'}\n" + + f"Bit Depth: {self.videoBitDepth} bit -> {new_bitrate if self.settings.settings['auto_hdr_mode'] == 'True' else 8} bit\n" ) - self.videoInfoTextEdit.setFontPointSize(10) - self.videoInfoTextEdit.setText(text) + + self.inputVideoInfoTextEdit.setFontPointSize(10) + self.inputVideoInfoTextEdit.setText(inputText) def getInterpolationMultiplier(self, interpolateModelName): if interpolateModelName == 
"None" or not self.interpolateCheckBox.isChecked(): @@ -427,6 +429,7 @@ def updateVideoGUIDetails(self): isDeblur = self.deblurCheckBox.isChecked() isDenoise = self.denoiseCheckBox.isChecked() isDecompress = self.decompressCheckBox.isChecked() + self.interpolationContainer.setVisible(isInterpolate) self.interpolateContainer_2.setVisible(isInterpolate) self.deblurContainer.setVisible(isDeblur) @@ -446,7 +449,13 @@ def updateVideoGUIDetails(self): self.startTimeSpinBox.setMaximum(self.videoLength) self.endTimeSpinBox.setMaximum(self.videoLength) self.timeInVideoScrollBar.setMaximum(self.videoLength) - + if isUpscale and (self.upscaleModelComboBox.currentText() != "" or self.upscaleModelComboBox.currentText() != "None"): + try: + max_scale = totalModels[self.upscaleModelComboBox.currentText()][2] + self.upscaleScaleSpinBox.setMaximum(max_scale if max_scale > 0 else 4) + except KeyError: # idk why it does this, gui is shit tbh. + self.upscaleScaleSpinBox.setMaximum(4) + def getCurrentRenderOptions(self, input_file=None, output_path=None): interpolate = self.interpolateModelComboBox.currentText() upscale = self.upscaleModelComboBox.currentText() @@ -679,6 +688,8 @@ def disableProcessPage(self): child.setEnabled(False) for child in self.renderQueueTab.children(): child.setEnabled(False) + for child in self.encoderSettings.children(): + child.setEnabled(False) self.RenderedPreviewControlsContainer.setEnabled(False) self.scrollArea_4.setEnabled(True) self.scrollAreaWidgetContents_4.setEnabled(False) @@ -691,6 +702,8 @@ def enableProcessPage(self): child.setEnabled(True) for child in self.renderQueueTab.children(): child.setEnabled(True) + for child in self.encoderSettings.children(): + child.setEnabled(True) self.RenderedPreviewControlsContainer.setEnabled(True) self.scrollAreaWidgetContents_4.setEnabled(True) @@ -753,6 +766,7 @@ def loadVideo(self, inputFile, multi_file=False): self.outputFileText.setEnabled(True) self.outputFileSelectButton.setEnabled(True) + self.openOutputFolderButton.setEnabled(True) self.isVideoLoaded = True self.updateVideoGUIDetails() @@ -915,11 +929,11 @@ def main(): app.setStyle("Fusion") app.setPalette(Palette()) - if not "--unlock" in sys.argv: + """if not "--unlock" in sys.argv: lock_file = QLockFile(LOCKFILE) if not lock_file.tryLock(10): QMessageBox.warning(None, "Instance Running", "Another instance is already running.") - sys.exit(0) + sys.exit(0)""" # setting the pallette window = MainWindow() @@ -933,7 +947,6 @@ def main(): """ custom command args --debug: runs the app in debug mode ---unlock: allows more than one instance to be launched --fullscreen: runs the app in fullscreen --swap-flatpak-checks: swaps the flatpak checks, ex if the app is running in flatpak, it will run as if it is not """ diff --git a/backend/rve-backend.py b/backend/rve-backend.py index ad9797e00..656b22faa 100644 --- a/backend/rve-backend.py +++ b/backend/rve-backend.py @@ -2,6 +2,8 @@ import argparse import sys from src.version import __version__ +from src.utils.Util import log + class HandleApplication: def __init__(self): @@ -14,7 +16,7 @@ def __init__(self): """from pyinstrument import Profiler profiler = Profiler() profiler.start()""" - + from src.utils.VideoInfo import OpenCVInfo, print_video_info if self.args.print_video_info: @@ -35,6 +37,13 @@ def __init__(self): download_ffmpeg() if not self.batchProcessing(): + buffer_str = "=" * len(str(sys.argv[0])) + log(buffer_str, False) + log("RVE Backend Version: " + __version__, False) + log(buffer_str, False) + log("CLI Arguments: ", 
False) + log(str(sys.argv), False) + log(buffer_str, False) self.renderVideo() else: diff --git a/backend/src/FFmpegBuffers.py b/backend/src/FFmpegBuffers.py index d8eb3e3d9..8be61df96 100644 --- a/backend/src/FFmpegBuffers.py +++ b/backend/src/FFmpegBuffers.py @@ -1,9 +1,9 @@ import queue +import sys from abc import ABC, abstractmethod import os import subprocess import queue -import sys import time import cv2 import numpy as np @@ -36,7 +36,7 @@ def __init__(self, inputFile, width, height, start_time, end_time, borderX, bord self.color_transfer = color_transfer self.input_pixel_format = input_pixel_format self.yuv420pMOD = self.input_pixel_format == "yuv420p" and not self.hdr_mode - + #self.yuv420pMOD = False if self.hdr_mode: self.inputFrameChunkSize = width * height * 6 else: @@ -44,7 +44,8 @@ def __init__(self, inputFile, width, height, start_time, end_time, borderX, bord self.inputFrameChunkSize = width * height * 3 // 2 else: self.inputFrameChunkSize = width * height * 3 - + command = self.command() + log("FFMPEG READ COMMAND: " + str(command)) self.readProcess = subprocess_popen_without_terminal( self.command(), stdout=subprocess.PIPE, @@ -53,15 +54,14 @@ def __init__(self, inputFile, width, height, start_time, end_time, borderX, bord self.readQueue = queue.Queue(maxsize=25) def command(self): - log("Generating FFmpeg READ command...") command = [ f"{FFMPEG_PATH}", "-i", f"{self.inputFile}", ] - - filter_string = f"crop={self.width}:{self.height}:{self.borderX}:{self.borderY},scale=w=iw*sar:h=ih" # fix dar != sar + + filter_string = f"crop={self.width}:{self.height}:{self.borderX}:{self.borderY},scale=w=iw*sar:h=ih" #+ ":in_range=limited:out_range=full,format=yuv420p" if self.yuv420pMOD == "yuv420p" else "" # fix dar != sar #if not self.hdr_mode: # if self.input_pixel_format == "yuv420p": # filter_string += ":in_range=tv:out_range=pc" # color shifts a smidgen but helps with artifacts when converting yuv to raw @@ -182,7 +182,6 @@ def __init__( self.color_space = color_space self.color_primaries = color_primaries self.color_transfer = color_transfer - log(f"FFmpegWrite parameters: inputFile={inputFile}, outputFile={outputFile}, width={width}, height={height}, start_time={start_time}, end_time={end_time}, fps={fps}, crf={crf}, audio_bitrate={audio_bitrate}, pixelFormat={pixelFormat}, overwrite={overwrite}, custom_encoder={custom_encoder}, benchmark={benchmark}, slowmo_mode={slowmo_mode}, upscaleTimes={upscaleTimes}, interpolateFactor={interpolateFactor}, ceilInterpolateFactor={ceilInterpolateFactor}, video_encoder={video_encoder}, audio_encoder={audio_encoder}, subtitle_encoder={subtitle_encoder}, hdr_mode={hdr_mode}, mpv_output={mpv_output}, merge_subtitles={merge_subtitles}") self.outputFPS = ( (self.fps * self.interpolateFactor) if not self.slowmo_mode @@ -190,9 +189,10 @@ def __init__( ) self.ffmpeg_log = open(FFMPEG_LOG_FILE, "w", encoding='utf-8') try: - + command = self.command() + log("\nFFMPEG WRITE COMMAND: " + str(command) + "\n") self.writeProcess = subprocess_popen_without_terminal( - self.command(), + command, stdin=subprocess.PIPE, stderr=self.ffmpeg_log, stdout=subprocess.PIPE if self.mpv_output else self.ffmpeg_log, @@ -386,7 +386,7 @@ def command(self): "-", ] - log("FFMPEG WRITE COMMAND: " + str(command)) + return command def get_num_frames_rendered(self): @@ -440,7 +440,7 @@ def merge_subtitles(self): log("Benchmark mode enabled, skipping subtitle merge.") return - temp_output = self.outputFile + ".temp.mkv" + temp_output = self.outputFile + "-" + str(os.getpid()) 
+ "-temp.mkv" os.rename(self.outputFile, temp_output) command = [ @@ -462,9 +462,6 @@ def merge_subtitles(self): self.outputFile, ] - if self.overwrite: - command.append("-y") - log("Merging subtitles with command: " + " ".join(command)) try: @@ -472,6 +469,7 @@ def merge_subtitles(self): if result.returncode != 0: log("Failed to merge subtitles. FFmpeg error:") log(result.stderr.decode()) + os.remove(self.outputFile) # Remove incomplete output file os.rename(temp_output, self.outputFile) # Restore original file return os.remove(temp_output) diff --git a/backend/src/RenderVideo.py b/backend/src/RenderVideo.py index 734054a10..b8e87921a 100644 --- a/backend/src/RenderVideo.py +++ b/backend/src/RenderVideo.py @@ -11,7 +11,7 @@ from .FFmpeg import InformationWriteOut from .utils.Encoders import EncoderSettings from .utils.SceneDetect import SceneDetect -from .utils.Util import log, bytesToImg, resize_image_bytes +from .utils.Util import log, resize_image_bytes from .utils.BorderDetect import BorderDetect from .utils.VideoInfo import OpenCVInfo import numpy as np @@ -215,7 +215,7 @@ def __init__( log(f"Interpolate Factor: {self.interpolateFactor}") log(f"Total Output Frames: {self.totalOutputFrames}") log("Model Scale: " + str(self.modelScale)) - print("HDR Mode: " + str(hdr_mode), file=sys.stderr) + log("HDR Mode: " + str(hdr_mode)) self.readBuffer = FFmpegRead( # input width inputFile=inputFile, diff --git a/backend/src/pytorch/InterpolateArchs/GIMM/GIMM.py b/backend/src/pytorch/InterpolateArchs/GIMM/GIMM.py deleted file mode 100644 index b4b220334..000000000 --- a/backend/src/pytorch/InterpolateArchs/GIMM/GIMM.py +++ /dev/null @@ -1,107 +0,0 @@ -from gimmvfi_r import GIMMVFI_R - -import torch -import torch.nn.functional as F -import os -from PIL import Image -import numpy as np - - -class InputPadder: - """Pads images such that dimensions are divisible by divisor""" - - def __init__(self, dims, divisor=16): - self.ht, self.wd = dims[-2:] - pad_ht = (((self.ht // divisor) + 1) * divisor - self.ht) % divisor - pad_wd = (((self.wd // divisor) + 1) * divisor - self.wd) % divisor - self._pad = [ - pad_wd // 2, - pad_wd - pad_wd // 2, - pad_ht // 2, - pad_ht - pad_ht // 2, - ] - - def pad(self, *inputs): - if len(inputs) == 1: - return F.pad(inputs[0], self._pad, mode="replicate") - else: - return [F.pad(x, self._pad, mode="replicate") for x in inputs] - - def unpad(self, *inputs): - if len(inputs) == 1: - return self._unpad(inputs[0]) - else: - return [self._unpad(x) for x in inputs] - - def _unpad(self, x): - ht, wd = x.shape[-2:] - c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]] - return x[..., c[0] : c[1], c[2] : c[3]] - - -device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "xpu" if torch.xpu.is_available() else "cpu") -model = GIMMVFI_R("GIMMVFI_RAFT.pth").to(device) - - -def convert(param): - return {k.replace("module.", ""): v for k, v in param.items() if "module" in k} - - -ckpt = torch.load("gimmvfi_r_arb_lpips.pt", map_location="cpu") -raft = torch.load("raft-things.pth", map_location="cpu") -combined_state_dict = {"gimmvfi_r": ckpt["state_dict"], "raft": convert(raft)} -torch.save(combined_state_dict, "GIMMVFI_RAFT.pth") -model.load_state_dict(combined_state_dict["gimmvfi_r"]) - -images = [] - - -def load_image(img_path): - img = Image.open(img_path) - raw_img = np.array(img.convert("RGB")) - img = torch.from_numpy(raw_img.copy()).permute(2, 0, 1) / 255.0 - return img.to(torch.float).unsqueeze(0) - - 
-img_path0 = "0001.png" -img_path2 = "0004.png" -# prepare data b,c,h,w -I0 = load_image(img_path0) -I2 = load_image(img_path2) -padder = InputPadder(I0.shape, 32) -I0, I2 = padder.pad(I0, I2) -xs = torch.cat((I0.unsqueeze(2), I2.unsqueeze(2)), dim=2).to(device, non_blocking=True) -print(I0.shape) -print(xs.shape) -model.eval() -batch_size = xs.shape[0] -s_shape = xs.shape[-2:] - -model.zero_grad() -ds_factor = 0.5 -interp_factor = 4 -with torch.no_grad(): - coord_inputs = [ - ( - model.sample_coord_input( - batch_size, - s_shape, - [1 / interp_factor * i], - device=xs.device, - upsample_ratio=ds_factor, - ), - None, - ) - for i in range(1, interp_factor) - ] - timesteps = [ - i - * 1 - / interp_factor - * torch.ones(xs.shape[0]).to(xs.device).to(torch.float).reshape(-1, 1, 1, 1) - for i in range(1, interp_factor) - ] - output = model(xs, coord_inputs[2], timestep=timesteps[2], ds_factor=ds_factor) - # out_flowts = [padder.unpad(f) for f in all_outputs["flowt"]] - - images.append((output.detach().cpu().numpy()).astype(np.uint8)) diff --git a/backend/src/pytorch/InterpolateArchs/GIMM/LICENSE.txt b/backend/src/pytorch/InterpolateArchs/GIMM/LICENSE.txt deleted file mode 100644 index ff369b4ed..000000000 --- a/backend/src/pytorch/InterpolateArchs/GIMM/LICENSE.txt +++ /dev/null @@ -1,13 +0,0 @@ -S-Lab License 1.0 -Copyright 2024 S-Lab - -Redistribution and use for non-commercial purpose in source and binary forms, with or without modification, are permitted provided that the following conditions are met: -1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. -2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. -3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. -IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -In the event that redistribution and/or use for commercial purpose in source or binary forms, with or without modification is required, please contact the contributor(s) of the work. diff --git a/backend/src/pytorch/InterpolateArchs/GIMM/gimmvfi_r.py b/backend/src/pytorch/InterpolateArchs/GIMM/gimmvfi_r.py deleted file mode 100644 index 0b69e7b62..000000000 --- a/backend/src/pytorch/InterpolateArchs/GIMM/gimmvfi_r.py +++ /dev/null @@ -1,527 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
-# -------------------------------------------------------- -# References: -# amt: https://github.com/MCG-NKU/AMT -# motif: https://github.com/sichun233746/MoTIF -# ginr-ipc: https://github.com/kakaobrain/ginr-ipc -# -------------------------------------------------------- - - -from tabnanny import check -import torch -import torch.nn as nn -import torch.nn.functional as F -from ....constants import HAS_SYSTEM_CUDA - -try: - from .raft import ( - normalize_flow, - unnormalize_flow, - warp, - resize, - build_coord, - multi_flow_combine, - NewInitDecoder, - NewMultiFlowDecoder, - BasicUpdateBlock, - LateralBlock, - HypoNet, - CoordSampler3D, - ) -except ImportError: - from raft import ( - normalize_flow, - unnormalize_flow, - warp, - resize, - build_coord, - multi_flow_combine, - NewInitDecoder, - NewMultiFlowDecoder, - BasicUpdateBlock, - LateralBlock, - HypoNet, - CoordSampler3D, - ) -try: - from .raftarch import RAFT, BidirCorrBlock -except ImportError: - from raftarch import RAFT, BidirCorrBlock -try: - if HAS_SYSTEM_CUDA: - from ..util.softsplat_cupy import softsplat - else: - from ..util.softsplat_torch import softsplat -except ImportError: - from softsplat_torch import softsplat - - -class GIMMVFI_R(nn.Module): - def __init__(self, model_path, width=1920, height=1080): - super().__init__() - self.raft_iter = 20 - self.width = width - self.height = height - - ######### Encoder and Decoder Settings ######### - model = RAFT() - ckpt = torch.load(model_path) - model.load_state_dict(ckpt["raft"], strict=True) - self.flow_estimator = model - - cur_f_dims = [128, 96] - f_dims = [256, 128] - - skip_channels = f_dims[-1] // 2 - self.num_flows = 3 - - self.amt_last_cproj = nn.Conv2d(cur_f_dims[0], f_dims[0], 1) - self.amt_second_last_cproj = nn.Conv2d(cur_f_dims[1], f_dims[1], 1) - self.amt_fproj = nn.Conv2d(f_dims[0], f_dims[0], 1) - self.amt_init_decoder = NewInitDecoder(f_dims[0], skip_channels) - self.amt_final_decoder = NewMultiFlowDecoder(f_dims[1], skip_channels) - - self.amt_update4_low = self._get_updateblock(f_dims[0] // 2, 2.0) - self.amt_update4_high = self._get_updateblock(f_dims[0] // 2, None) - - self.amt_comb_block = nn.Sequential( - nn.Conv2d(3 * self.num_flows, 6 * self.num_flows, 7, 1, 3), - nn.PReLU(6 * self.num_flows), - nn.Conv2d(6 * self.num_flows, 3, 7, 1, 3), - ) - - ################ GIMM settings ################# - self.coord_sampler = CoordSampler3D([-1.0, 1.0]) - - self.g_filter = torch.nn.Parameter( - torch.Tensor( - [ - [1.0 / 16.0, 1.0 / 8.0, 1.0 / 16.0], - [1.0 / 8.0, 1.0 / 4.0, 1.0 / 8.0], - [1.0 / 16.0, 1.0 / 8.0, 1.0 / 16.0], - ] - ).reshape(1, 1, 1, 3, 3), - requires_grad=False, - ) - self.fwarp_type = "linear" - - self.alpha_v = torch.nn.Parameter(torch.Tensor([1]), requires_grad=True) - self.alpha_fe = torch.nn.Parameter(torch.Tensor([1]), requires_grad=True) - - channel = 32 - in_dim = 2 - self.cnn_encoder = nn.Sequential( - nn.Conv2d(in_dim, channel // 2, 3, 1, 1, bias=True, groups=1), - nn.Conv2d(channel // 2, channel, 3, 1, 1, bias=True, groups=1), - nn.LeakyReLU(negative_slope=0.1, inplace=True), - LateralBlock(channel), - LateralBlock(channel), - LateralBlock(channel), - nn.LeakyReLU(negative_slope=0.1, inplace=True), - nn.Conv2d( - channel, channel // 2, 3, 1, 1, padding_mode="reflect", bias=True - ), - ) - channel = 64 - in_dim = 64 - self.res_conv = nn.Sequential( - nn.Conv2d(in_dim, channel // 2, 3, 1, 1, bias=True, groups=1), - nn.Conv2d(channel // 2, channel, 3, 1, 1, bias=True, groups=1), - nn.LeakyReLU(negative_slope=0.1, inplace=True), - 
LateralBlock(channel), - nn.LeakyReLU(negative_slope=0.1, inplace=True), - nn.Conv2d( - channel, channel // 2, 3, 1, 1, padding_mode="reflect", bias=True - ), - ) - - self.hyponet = HypoNet(add_coord_dim=32) - - def _get_updateblock(self, cdim, scale_factor=None): - return BasicUpdateBlock( - cdim=cdim, - hidden_dim=192, - flow_dim=64, - corr_dim=256, - corr_dim2=192, - fc_dim=188, - scale_factor=scale_factor, - corr_levels=4, - radius=4, - ) - - def cal_bidirection_flow(self, im0, im1, iters=20): - f01, features0, fnet0 = self.flow_estimator( - im0, im1, return_feat=True, iters=20 - ) - f10, features1, fnet1 = self.flow_estimator( - im1, im0, return_feat=True, iters=20 - ) - corr_fn = BidirCorrBlock(self.amt_fproj(fnet0), self.amt_fproj(fnet1), radius=4) - features0 = [ - self.amt_second_last_cproj(features0[0]), - self.amt_last_cproj(features0[1]), - ] - features1 = [ - self.amt_second_last_cproj(features1[0]), - self.amt_last_cproj(features1[1]), - ] - flow01 = f01.unsqueeze(2) - flow10 = f10.unsqueeze(2) - noraml_flows = torch.cat([flow01, -flow10], dim=2) - noraml_flows, flow_scalers = normalize_flow(noraml_flows) - - ori_flows = torch.cat([flow01, flow10], dim=2) - return ( - noraml_flows, - ori_flows, - flow_scalers, - features0, - features1, - corr_fn, - torch.cat([f01.unsqueeze(2), f10.unsqueeze(2)], dim=2), - ) - - def predict_flow(self, f, cur_coord, cur_t, flows): - def check_for_nans(tensor, name): - if type(tensor) == list: - for t in tensor: - if torch.isnan(t).any(): - print(f"NaNs found in {name}") - else: - if torch.isnan(tensor).any(): - print(f"NaNs found in {name}") - - raft_flow01 = flows[:, :, 0].detach() - raft_flow10 = flows[:, :, 1].detach() - # check_for_nans(raft_flow01, "raft_flow01") - # check_for_nans(raft_flow10, "raft_flow10") - - # calculate splatting metrics - weights1, weights2 = self.cal_splatting_weights(raft_flow01, raft_flow10) - # check_for_nans(weights1, "weights1") - # check_for_nans(weights2, "weights2") - strtype = self.fwarp_type + "-zeroeps" - # check_for_nans(f, "f") - - # b,c,h,w - pixel_latent_0 = self.cnn_encoder(f[:, :, 0]) - pixel_latent_1 = self.cnn_encoder(f[:, :, 1]) - # check_for_nans(pixel_latent_0, "pixel_latent_0") - # check_for_nans(pixel_latent_1, "pixel_latent_1") - - tmp_pixel_latent_0 = softsplat( - tenIn=pixel_latent_0, - tenFlow=raft_flow01 * cur_t, - tenMetric=weights1, - strMode=strtype, - ) - tmp_pixel_latent_1 = softsplat( - tenIn=pixel_latent_1, - tenFlow=raft_flow10 * (1 - cur_t), - tenMetric=weights2, - strMode=strtype, - ) - - tmp_pixel_latent = torch.cat([tmp_pixel_latent_0, tmp_pixel_latent_1], dim=1) - # check_for_nans(tmp_pixel_latent, "tmp_pixel_latent") - tmp_pixel_latent = tmp_pixel_latent + self.res_conv( - torch.cat([pixel_latent_0, pixel_latent_1, tmp_pixel_latent], dim=1) - ) - # check_for_nans(tmp_pixel_latent, "tmp_pixel_latent") - permute_idx_range = [i for i in range(1, f.ndim - 1)] - - if cur_coord[1] is None: - outputs = self.hyponet( - cur_coord, - modulation_params_dict=None, - pixel_latent=tmp_pixel_latent.permute(0, 2, 3, 1), - ).permute(0, -1, *permute_idx_range) - else: - outputs = self.hyponet( - cur_coord, - modulation_params_dict=None, - pixel_latent=tmp_pixel_latent.permute(0, 2, 3, 1), - ) - # check_for_nans(outputs, "outputs") - return outputs - - def warp_w_mask(self, img0, img1, ft0, ft1, mask, scale=1): - ft0 = scale * resize(ft0, scale_factor=scale) - ft1 = scale * resize(ft1, scale_factor=scale) - mask = resize(mask, scale_factor=scale).sigmoid() - img0_warp = warp(img0, ft0) - 
img1_warp = warp(img1, ft1) - img_warp = mask * img0_warp + (1 - mask) * img1_warp - return img_warp - - def frame_synthesize( - self, img_xs, flow_t, features0, features1, corr_fn, cur_t, full_img=None - ): - """ - flow_t: b,2,h,w - cur_t: b,1,1,1 - """ - batch_size = img_xs.shape[0] # b,c,t,h,w - img0 = 2 * img_xs[:, :, 0] - 1.0 - img1 = 2 * img_xs[:, :, 1] - 1.0 - - ##################### update the predicted flow ##################### - ##initialize coordinates for looking up - lookup_coord = build_coord(img_xs[:, :, 0]).to( - img_xs[:, :, 0].device - ) # H//8,W//8 - - flow_t0_fullsize = flow_t * (-cur_t) - flow_t1_fullsize = flow_t * (1.0 - cur_t) - - inv = 1 / 4 - flow_t0_inr4 = inv * resize(flow_t0_fullsize, inv) - flow_t1_inr4 = inv * resize(flow_t1_fullsize, inv) - - ############################# scale 1/4 ############################# - # i. Initialize feature t at scale 1/4 - flowt0_4, flowt1_4, ft_4_ = self.amt_init_decoder( - features0[-1], - features1[-1], - flow_t0_inr4, - flow_t1_inr4, - img0=img0, - img1=img1, - ) - features0, features1 = features0[:-1], features1[:-1] - - mask_4_, ft_4_ = ft_4_[:, :1], ft_4_[:, 1:] - img_warp_4 = self.warp_w_mask(img0, img1, flowt0_4, flowt1_4, mask_4_, scale=4) - img_warp_4 = (img_warp_4 + 1.0) / 2 - img_warp_4 = torch.clamp(img_warp_4, 0, 1) - - corr_4, flow_4_lr = self._amt_corr_scale_lookup( - corr_fn, lookup_coord, flowt0_4, flowt1_4, cur_t, downsample=2 - ) - - delta_ft_4_, delta_flow_4 = self.amt_update4_low(ft_4_, flow_4_lr, corr_4) - delta_flow0_4, delta_flow1_4 = torch.chunk(delta_flow_4, 2, 1) - flowt0_4 = flowt0_4 + delta_flow0_4 - flowt1_4 = flowt1_4 + delta_flow1_4 - ft_4_ = ft_4_ + delta_ft_4_ - - # iii. residue update with lookup corr - corr_4 = resize(corr_4, scale_factor=2.0) - - flow_4 = torch.cat([flowt0_4, flowt1_4], dim=1) - delta_ft_4_, delta_flow_4 = self.amt_update4_high(ft_4_, flow_4, corr_4) - flowt0_4 = flowt0_4 + delta_flow_4[:, :2] - flowt1_4 = flowt1_4 + delta_flow_4[:, 2:4] - ft_4_ = ft_4_ + delta_ft_4_ - - ############################# scale 1/1 ############################# - flowt0_1, flowt1_1, mask, img_res = self.amt_final_decoder( - ft_4_, - features0[0], - features1[0], - flowt0_4, - flowt1_4, - mask=mask_4_, - img0=img0, - img1=img1, - ) - - if full_img is not None: - img0 = 2 * full_img[:, :, 0] - 1.0 - img1 = 2 * full_img[:, :, 1] - 1.0 - inv = img1.shape[2] / flowt0_1.shape[2] - flowt0_1 = inv * resize(flowt0_1, scale_factor=inv) - flowt1_1 = inv * resize(flowt1_1, scale_factor=inv) - flow_t0_fullsize = inv * resize(flow_t0_fullsize, scale_factor=inv) - flow_t1_fullsize = inv * resize(flow_t1_fullsize, scale_factor=inv) - mask = resize(mask, scale_factor=inv) - img_res = resize(img_res, scale_factor=inv) - - imgt_pred = multi_flow_combine( - self.amt_comb_block, img0, img1, flowt0_1, flowt1_1, mask, img_res, None - ) - imgt_pred = torch.clamp(imgt_pred, 0, 1) - - ###################################################################### - - flowt0_1 = flowt0_1.reshape( - batch_size, self.num_flows, 2, img0.shape[-2], img0.shape[-1] - ) - flowt1_1 = flowt1_1.reshape( - batch_size, self.num_flows, 2, img0.shape[-2], img0.shape[-1] - ) - - flowt0_pred = [flowt0_1, flowt0_4] - flowt1_pred = [flowt1_1, flowt1_4] - other_pred = [img_warp_4] - return imgt_pred, flowt0_pred, flowt1_pred, other_pred - - def forward(self, img_xs, coord=None, timestep=None, iters=None, ds_factor=None): - indtype = img_xs.dtype - indevice = img_xs.device - - full_size_img = None - if ds_factor is not None: - full_size_img = 
img_xs.clone() - img_xs = torch.cat( - [ - resize(img_xs[:, :, 0], scale_factor=ds_factor).unsqueeze(2), - resize(img_xs[:, :, 1], scale_factor=ds_factor).unsqueeze(2), - ], - dim=2, - ).to(dtype=indtype, device=indevice) - iters = self.raft_iter if iters is None else iters - ( - normal_flows, - flows, - flow_scalers, - features0, - features1, - corr_fn, - preserved_raft_flows, - ) = self.cal_bidirection_flow( - 255 * img_xs[:, :, 0], 255 * img_xs[:, :, 1], iters=iters - ) - - # List of flows - normal_inr_flows = self.predict_flow(normal_flows, coord, timestep, flows) - cur_flow_t = unnormalize_flow(normal_inr_flows, flow_scalers).squeeze() - - if cur_flow_t.ndim != 4: - cur_flow_t = cur_flow_t.unsqueeze(0) - assert cur_flow_t.ndim == 4 - - imgt_pred, flowt0_pred, flowt1_pred, others = self.frame_synthesize( - img_xs, - cur_flow_t, - features0, - features1, - corr_fn, - timestep, - full_img=full_size_img, - ) - - return imgt_pred[:, :, : self.height, : self.width] - - def warp_frame(self, frame, flow): - return warp(frame, flow) - - def compute_psnr(self, preds, targets, reduction="mean"): - assert reduction in ["mean", "sum", "none"] - batch_size = preds.shape[0] - sample_mses = torch.reshape((preds - targets) ** 2, (batch_size, -1)).mean( - dim=-1 - ) - - if reduction == "mean": - psnr = (-10 * torch.log10(sample_mses)).mean() - elif reduction == "sum": - psnr = (-10 * torch.log10(sample_mses)).sum() - else: - psnr = -10 * torch.log10(sample_mses) - - return psnr - - def sample_coord_input( - self, - batch_size, - s_shape, - t_ids, - coord_range=None, - upsample_ratio=1.0, - device=None, - ): - assert device is not None - assert coord_range is None - coord_inputs = self.coord_sampler( - batch_size, s_shape, t_ids, coord_range, upsample_ratio, device - ) - return coord_inputs - - def cal_splatting_weights(self, raft_flow01, raft_flow10): - def check_for_nans(tensor, name): - if type(tensor) == list: - for t in tensor: - if torch.isnan(t).any(): - print(f"NaNs found in {name}") - else: - if torch.isnan(tensor).any(): - print(f"NaNs found in {name}") - - batch_size = raft_flow01.shape[0] - raft_flows = torch.cat([raft_flow01, raft_flow10], dim=0) - - ## flow variance metric - sqaure_mean, mean_square = torch.split( - F.conv3d( - F.pad( - torch.cat([raft_flows**2, raft_flows], 1), - (1, 1, 1, 1), - mode="reflect", - ).unsqueeze(1), - self.g_filter, - ).squeeze(1), - 2, - dim=1, - ) - # check_for_nan(sqaure_mean, "sqaure_mean") - # check_for_nan(mean_square, "mean_square") - var = ( - (sqaure_mean.float() - mean_square.float() ** 2) - .clamp(1e-9, None) - .sqrt() - .mean(1) - .unsqueeze(1) - ).to(raft_flow01.dtype) - # check_for_nan(var, "var") - var01 = var[:batch_size] - var10 = var[batch_size:] - - ## flow warp metirc - f01_warp = -warp(raft_flow10, raft_flow01) - f10_warp = -warp(raft_flow01, raft_flow10) - # check_for_nan(f01_warp, "f01_warp") - # check_for_nan(f10_warp, "f10_warp") - err01 = ( - torch.nn.functional.l1_loss( - input=f01_warp, target=raft_flow01, reduction="none" - ) - .mean(1) - .unsqueeze(1) - ) - err02 = ( - torch.nn.functional.l1_loss( - input=f10_warp, target=raft_flow10, reduction="none" - ) - .mean(1) - .unsqueeze(1) - ) - # check_for_nan(err01, "err01") - # check_for_nan(err02, "err02") - - weights1 = 1 / (1 + err01 * self.alpha_fe) + 1 / (1 + var01 * self.alpha_v) - weights2 = 1 / (1 + err02 * self.alpha_fe) + 1 / (1 + var10 * self.alpha_v) - # check_for_nan(weights1, "weights1") - # check_for_nan(weights2, "weights2") - - return weights1, weights2 - - def 
_amt_corr_scale_lookup(self, corr_fn, coord, flow0, flow1, embt, downsample=1): - # convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0 - # based on linear assumption - t0_scale = 1.0 / embt - t1_scale = 1.0 / (1.0 - embt) - if downsample != 1: - inv = 1 / downsample - flow0 = inv * resize(flow0, scale_factor=inv) - flow1 = inv * resize(flow1, scale_factor=inv) - - corr0, corr1 = corr_fn(coord + flow1 * t1_scale, coord + flow0 * t0_scale) - corr = torch.cat([corr0, corr1], dim=1) - flow = torch.cat([flow0, flow1], dim=1) - return corr, flow diff --git a/backend/src/pytorch/InterpolateArchs/GIMM/raft.py b/backend/src/pytorch/InterpolateArchs/GIMM/raft.py deleted file mode 100644 index 6aff07062..000000000 --- a/backend/src/pytorch/InterpolateArchs/GIMM/raft.py +++ /dev/null @@ -1,756 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -# -------------------------------------------------------- -# References: -# ginr-ipc: https://github.com/kakaobrain/ginr-ipc -# -------------------------------------------------------- - -import torch -import torch.nn as nn -import math -import einops -import torch.nn.functional as F - -device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "xpu" if torch.xpu.is_available() else "cpu") -backwarp_tenGrid = {} - - -def warp(tenInput, tenFlow): - origdtype = tenInput.dtype - tenFlow = tenFlow.float() - tenInput = tenInput.float() - k = (str(tenFlow.device), str(tenFlow.size())) - if k not in backwarp_tenGrid: - tenHorizontal = ( - torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device) - .view(1, 1, 1, tenFlow.shape[3]) - .expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1) - ).float() - tenVertical = ( - torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device) - .view(1, 1, tenFlow.shape[2], 1) - .expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3]) - ).float() - backwarp_tenGrid[k] = ( - torch.cat([tenHorizontal, tenVertical], 1).to(device).float() - ) - - tenFlow = torch.cat( - [ - tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), - tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0), - ], - 1, - ).float() - - g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1).float() - pd = 'border' - if tenInput.device.type == "mps": - pd = 'zeros' - g = g.clamp(-1, 1) - return torch.nn.functional.grid_sample( - input=tenInput, - grid=g, - mode="bilinear", - padding_mode=pd, - align_corners=True, - ).to(dtype=origdtype) - - -def normalize_flow(flows): - # FIXME: MULTI-DIMENSION - flow_scaler = torch.max(torch.abs(flows).flatten(1), dim=-1)[0].reshape( - -1, 1, 1, 1, 1 - ) - flows = flows / flow_scaler # [-1,1] - # # Adapt to [0,1] - flows = (flows + 1.0) / 2.0 - return flows, flow_scaler - - -def unnormalize_flow(flows, flow_scaler): - return (flows * 2.0 - 1.0) * flow_scaler - - -def resize(x, scale_factor): - return F.interpolate( - x, scale_factor=scale_factor, mode="bilinear", align_corners=False - ) - - -def coords_grid(batch, ht, wd): - coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) - coords = torch.stack(coords[::-1], dim=0) - return coords[None].repeat(batch, 1, 1, 1) - - -def build_coord(img): - N, C, H, W = img.shape - coords = coords_grid(N, H // 8, W // 8) - return coords - - -def initialize_params(params, init_type, **kwargs): - fan_in, fan_out = params.shape[0], params.shape[1] - if init_type is None or init_type 
== "normal": - nn.init.normal_(params) - elif init_type == "kaiming_uniform": - nn.init.kaiming_uniform_(params, a=math.sqrt(5)) - elif init_type == "uniform_fan_in": - bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 - nn.init.uniform_(params, -bound, bound) - elif init_type == "zero": - nn.init.zeros_(params) - elif "siren" == init_type: - assert "siren_w0" in kwargs.keys() and "is_first" in kwargs.keys() - w0 = kwargs["siren_w0"] - if kwargs["is_first"]: - w_std = 1 / fan_in - else: - w_std = math.sqrt(6.0 / fan_in) / w0 - nn.init.uniform_(params, -w_std, w_std) - else: - raise NotImplementedError - - -def create_params_with_init( - shape, init_type="normal", include_bias=False, bias_init_type="zero", **kwargs -): - if not include_bias: - params = torch.empty([shape[0], shape[1]]) - initialize_params(params, init_type, **kwargs) - return params - else: - params = torch.empty([shape[0] - 1, shape[1]]) - bias = torch.empty([1, shape[1]]) - - initialize_params(params, init_type, **kwargs) - initialize_params(bias, bias_init_type, **kwargs) - return torch.cat([params, bias], dim=0) - - -class CoordSampler3D(nn.Module): - def __init__(self, coord_range, t_coord_only=False): - super().__init__() - self.coord_range = coord_range - self.t_coord_only = t_coord_only - - def shape2coordinate( - self, - batch_size, - spatial_shape, - t_ids, - coord_range=(-1.0, 1.0), - upsample_ratio=1, - device=None, - ): - coords = [] - assert isinstance(t_ids, list) - _coords = torch.tensor(t_ids, device=device) / 1.0 - coords.append(_coords) - for num_s in spatial_shape: - num_s = int(num_s * upsample_ratio) - _coords = (0.5 + torch.arange(num_s, device=device)) / num_s - _coords = coord_range[0] + (coord_range[1] - coord_range[0]) * _coords - coords.append(_coords) - coords = torch.meshgrid(*coords, indexing="ij") - coords = torch.stack(coords, dim=-1) - ones_like_shape = (1,) * coords.ndim - coords = coords.unsqueeze(0).repeat(batch_size, *ones_like_shape) - return coords # (B,T,H,W,3) - - def batchshape2coordinate( - self, - batch_size, - spatial_shape, - t_ids, - coord_range=(-1.0, 1.0), - upsample_ratio=1, - device=None, - ): - coords = [] - _coords = torch.tensor(1, device=device) - coords.append(_coords) - for num_s in spatial_shape: - num_s = int(num_s * upsample_ratio) - _coords = (0.5 + torch.arange(num_s, device=device)) / num_s - _coords = coord_range[0] + (coord_range[1] - coord_range[0]) * _coords - coords.append(_coords) - coords = torch.meshgrid(*coords, indexing="ij") - coords = torch.stack(coords, dim=-1) - ones_like_shape = (1,) * coords.ndim - # Now coords b,1,h,w,3, coords[...,0]=1. - coords = coords.unsqueeze(0).repeat(batch_size, *ones_like_shape) - # assign per-sample timestep within the batch - coords[..., :1] = coords[..., :1] * t_ids.reshape(-1, 1, 1, 1, 1) - return coords - - def forward( - self, - batch_size, - s_shape, - t_ids, - coord_range=None, - upsample_ratio=1.0, - device=None, - ): - coord_range = self.coord_range if coord_range is None else coord_range - if isinstance(t_ids, list): - coords = self.shape2coordinate( - batch_size, s_shape, t_ids, coord_range, upsample_ratio, device - ) - elif isinstance(t_ids, torch.Tensor): - coords = self.batchshape2coordinate( - batch_size, s_shape, t_ids, coord_range, upsample_ratio, device - ) - if self.t_coord_only: - coords = coords[..., :1] - return coords - - -# define siren layer & Siren model -class Sine(nn.Module): - """Sine activation with scaling. - - Args: - w0 (float): Omega_0 parameter from SIREN paper. 
- """ - - def __init__(self, w0=1.0): - super().__init__() - self.w0 = w0 - - def forward(self, x): - return torch.sin(self.w0 * x) - - -class HypoNet(nn.Module): - r""" - The Hyponetwork with a coordinate-based MLP to be modulated. - """ - - def __init__(self, add_coord_dim=32): - super().__init__() - self.use_bias = True - self.num_layer = 5 - self.hidden_dims = [128] - self.add_coord_dim = add_coord_dim - - if len(self.hidden_dims) == 1: - self.hidden_dims = self.hidden_dims * ( - self.num_layer - 1 - ) # exclude output layer - else: - assert len(self.hidden_dims) == self.num_layer - 1 - - # after computes the shape of trainable parameters, initialize them - self.params_dict = None - self.params_shape_dict = self.compute_params_shape() - self.activation = Sine(1.0) - self.build_base_params_dict() - self.output_bias = 0.5 - - self.normalize_weight = True - - self.ignore_base_param_dict = {name: False for name in self.params_dict} - - @staticmethod - def subsample_coords(coords, subcoord_idx=None): - if subcoord_idx is None: - return coords - - batch_size = coords.shape[0] - sub_coords = [] - coords = coords.view(batch_size, -1, coords.shape[-1]) - for idx in range(batch_size): - sub_coords.append(coords[idx : idx + 1, subcoord_idx[idx]]) - sub_coords = torch.cat(sub_coords, dim=0) - return sub_coords - - def forward(self, coord, modulation_params_dict=None, pixel_latent=None): - origdtype = coord[0].dtype - sub_idx = None - if isinstance(coord, tuple): - coord, sub_idx = coord[0], coord[1] - - if modulation_params_dict is not None: - self.check_valid_param_keys(modulation_params_dict) - - batch_size, coord_shape, input_dim = ( - coord.shape[0], - coord.shape[1:-1], - coord.shape[-1], - ) - coord = coord.view(batch_size, -1, input_dim) # flatten the coordinates - assert pixel_latent is not None - pixel_latent = F.interpolate( - pixel_latent.permute(0, 3, 1, 2), - size=(coord_shape[1], coord_shape[2]), - mode="bilinear", - ).permute(0, 2, 3, 1) - pixel_latent_dim = pixel_latent.shape[-1] - pixel_latent = pixel_latent.view(batch_size, -1, pixel_latent_dim) - hidden = coord - - hidden = torch.cat([pixel_latent, hidden], dim=-1) - - hidden = self.subsample_coords(hidden, sub_idx) - - for idx in range(5): - param_key = f"linear_wb{idx}" - base_param = einops.repeat( - self.params_dict[param_key], "n m -> b n m", b=batch_size - ) - - if (modulation_params_dict is not None) and ( - param_key in modulation_params_dict.keys() - ): - modulation_param = modulation_params_dict[param_key] - else: - modulation_param = torch.ones_like(base_param[:, :-1]) - - ones = torch.ones(*hidden.shape[:-1], 1, device=hidden.device) - hidden = torch.cat([hidden, ones], dim=-1).to(dtype=origdtype) - - base_param_w, base_param_b = ( - base_param[:, :-1, :], - base_param[:, -1:, :], - ) - - if self.ignore_base_param_dict[param_key]: - base_param_w = 1.0 - param_w = base_param_w * modulation_param - if self.normalize_weight: - param_w = F.normalize(param_w, dim=1) - modulated_param = torch.cat([param_w, base_param_b], dim=1) - - # print([param_key,hidden.shape,modulated_param.shape]) - hidden = torch.bmm(hidden, modulated_param) - - if idx < (5 - 1): - hidden = self.activation(hidden) - - outputs = hidden + self.output_bias - if sub_idx is None: - outputs = outputs.view(batch_size, *coord_shape, -1) - return outputs - - def compute_params_shape(self): - """ - Computes the shape of MLP parameters. - The computed shapes are used to build the initial weights by `build_base_params_dict`. 
- """ - use_bias = self.use_bias - - param_shape_dict = dict() - - fan_in = 3 - add_dim = self.add_coord_dim - fan_in = fan_in + add_dim - fan_in = fan_in + 1 if use_bias else fan_in - - for i in range(4): - fan_out = self.hidden_dims[i] - param_shape_dict[f"linear_wb{i}"] = (fan_in, fan_out) - fan_in = fan_out + 1 if use_bias else fan_out - - param_shape_dict[f"linear_wb{4}"] = (fan_in, 2) - return param_shape_dict - - def build_base_params_dict(self): - assert self.params_shape_dict - params_dict = nn.ParameterDict() - for idx, (name, shape) in enumerate(self.params_shape_dict.items()): - is_first = idx == 0 - params = create_params_with_init( - shape, - init_type="siren", - include_bias=self.use_bias, - bias_init_type="siren", - is_first=is_first, - siren_w0=1.0, # valid only for siren - ) - params = nn.Parameter(params) - params_dict[name] = params - self.set_params_dict(params_dict) - - def check_valid_param_keys(self, params_dict): - predefined_params_keys = self.params_shape_dict.keys() - for param_key in params_dict.keys(): - if param_key in predefined_params_keys: - continue - else: - raise KeyError - - def set_params_dict(self, params_dict): - self.check_valid_param_keys(params_dict) - self.params_dict = params_dict - - -class LateralBlock(nn.Module): - def __init__(self, dim): - super(LateralBlock, self).__init__() - self.layers = nn.Sequential( - nn.Conv2d(dim, dim, 3, 1, 1, bias=True), - nn.LeakyReLU(negative_slope=0.1, inplace=True), - nn.Conv2d(dim, dim, 3, 1, 1, bias=True), - ) - - def forward(self, x): - res = x - x = self.layers(x) - return x + res - - -def convrelu( - in_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1, - dilation=1, - groups=1, - bias=True, -): - return nn.Sequential( - nn.Conv2d( - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - groups, - bias=bias, - ), - nn.PReLU(out_channels), - ) - - -def multi_flow_combine( - comb_block, img0, img1, flow0, flow1, mask=None, img_res=None, mean=None -): - assert mean is None - b, c, h, w = flow0.shape - num_flows = c // 2 - flow0 = flow0.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w) - flow1 = flow1.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w) - - mask = ( - mask.reshape(b, num_flows, 1, h, w).reshape(-1, 1, h, w) - if mask is not None - else None - ) - img_res = ( - img_res.reshape(b, num_flows, 3, h, w).reshape(-1, 3, h, w) - if img_res is not None - else 0 - ) - img0 = torch.stack([img0] * num_flows, 1).reshape(-1, 3, h, w) - img1 = torch.stack([img1] * num_flows, 1).reshape(-1, 3, h, w) - mean = ( - torch.stack([mean] * num_flows, 1).reshape(-1, 1, 1, 1) - if mean is not None - else 0 - ) - - img0_warp = warp(img0, flow0) - img1_warp = warp(img1, flow1) - img_warps = mask * img0_warp + (1 - mask) * img1_warp + mean + img_res - img_warps = img_warps.reshape(b, num_flows, 3, h, w) - - res = comb_block(img_warps.view(b, -1, h, w)) - imgt_pred = img_warps.mean(1) + res - - imgt_pred = (imgt_pred + 1.0) / 2 - - return imgt_pred - - -class ResBlock(nn.Module): - def __init__(self, in_channels, side_channels, bias=True): - super(ResBlock, self).__init__() - self.side_channels = side_channels - self.conv1 = nn.Sequential( - nn.Conv2d( - in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias - ), - nn.PReLU(in_channels), - ) - self.conv2 = nn.Sequential( - nn.Conv2d( - side_channels, - side_channels, - kernel_size=3, - stride=1, - padding=1, - bias=bias, - ), - nn.PReLU(side_channels), - ) - self.conv3 = nn.Sequential( - nn.Conv2d( - 
in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias - ), - nn.PReLU(in_channels), - ) - self.conv4 = nn.Sequential( - nn.Conv2d( - side_channels, - side_channels, - kernel_size=3, - stride=1, - padding=1, - bias=bias, - ), - nn.PReLU(side_channels), - ) - self.conv5 = nn.Conv2d( - in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias - ) - self.prelu = nn.PReLU(in_channels) - - def forward(self, x): - out = self.conv1(x) - - res_feat = out[:, : -self.side_channels, ...] - side_feat = out[:, -self.side_channels :, :, :] - side_feat = self.conv2(side_feat) - out = self.conv3(torch.cat([res_feat, side_feat], 1)) - - res_feat = out[:, : -self.side_channels, ...] - side_feat = out[:, -self.side_channels :, :, :] - side_feat = self.conv4(side_feat) - out = self.conv5(torch.cat([res_feat, side_feat], 1)) - - out = self.prelu(x + out) - return out - - -class BasicUpdateBlock(nn.Module): - def __init__( - self, - cdim, - hidden_dim, - flow_dim, - corr_dim, - corr_dim2, - fc_dim, - corr_levels=4, - radius=3, - scale_factor=None, - out_num=1, - ): - super(BasicUpdateBlock, self).__init__() - cor_planes = corr_levels * (2 * radius + 1) ** 2 - - self.scale_factor = scale_factor - self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0) - self.convc2 = nn.Conv2d(corr_dim, corr_dim2, 3, padding=1) - self.convf1 = nn.Conv2d(4, flow_dim * 2, 7, padding=3) - self.convf2 = nn.Conv2d(flow_dim * 2, flow_dim, 3, padding=1) - self.conv = nn.Conv2d(flow_dim + corr_dim2, fc_dim, 3, padding=1) - - self.gru = nn.Sequential( - nn.Conv2d(fc_dim + 4 + cdim, hidden_dim, 3, padding=1), - nn.LeakyReLU(negative_slope=0.1, inplace=True), - nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), - ) - - self.feat_head = nn.Sequential( - nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), - nn.LeakyReLU(negative_slope=0.1, inplace=True), - nn.Conv2d(hidden_dim, cdim, 3, padding=1), - ) - - self.flow_head = nn.Sequential( - nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), - nn.LeakyReLU(negative_slope=0.1, inplace=True), - nn.Conv2d(hidden_dim, 4 * out_num, 3, padding=1), - ) - - self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) - - def forward(self, net, flow, corr): - net = ( - resize(net, 1 / self.scale_factor) if self.scale_factor is not None else net - ) - cor = self.lrelu(self.convc1(corr)) - cor = self.lrelu(self.convc2(cor)) - flo = self.lrelu(self.convf1(flow)) - flo = self.lrelu(self.convf2(flo)) - cor_flo = torch.cat([cor, flo], dim=1) - inp = self.lrelu(self.conv(cor_flo)) - inp = torch.cat([inp, flow, net], dim=1) - - out = self.gru(inp) - delta_net = self.feat_head(out) - delta_flow = self.flow_head(out) - - if self.scale_factor is not None: - delta_net = resize(delta_net, scale_factor=self.scale_factor) - delta_flow = self.scale_factor * resize( - delta_flow, scale_factor=self.scale_factor - ) - return delta_net, delta_flow - - -def get_bn(): - return nn.BatchNorm2d - - -class NewInitDecoder(nn.Module): - def __init__(self, in_ch, skip_ch): - super().__init__() - norm_layer = get_bn() - - self.upsample = nn.Sequential( - nn.PixelShuffle(2), - convrelu(in_ch // 4, in_ch // 4, 5, 1, 2), - convrelu(in_ch // 4, in_ch // 4), - convrelu(in_ch // 4, in_ch // 4), - convrelu(in_ch // 4, in_ch // 4), - convrelu(in_ch // 4, in_ch // 2), - nn.Conv2d(in_ch // 2, in_ch // 2, kernel_size=1), - norm_layer(in_ch // 2), - nn.ReLU(inplace=True), - ) - - in_ch = in_ch // 2 - self.convblock = nn.Sequential( - convrelu(in_ch * 2 + 16, in_ch, kernel_size=1, padding=0), - ResBlock(in_ch, 
skip_ch), - ResBlock(in_ch, skip_ch), - ResBlock(in_ch, skip_ch), - nn.Conv2d(in_ch, in_ch + 5, 3, 1, 1, 1, 1, True), - ) - - def forward(self, f0, f1, flow0_in, flow1_in, img0=None, img1=None): - f0 = self.upsample(f0) - f1 = self.upsample(f1) - f0_warp_ks = warp(f0, flow0_in) - f1_warp_ks = warp(f1, flow1_in) - - f_in = torch.cat([f0_warp_ks, f1_warp_ks, flow0_in, flow1_in], dim=1) - - assert img0 is not None - assert img1 is not None - scale_factor = f_in.shape[2] / img0.shape[2] - img0 = resize(img0, scale_factor=scale_factor) - img1 = resize(img1, scale_factor=scale_factor) - warped_img0 = warp(img0, flow0_in) - warped_img1 = warp(img1, flow1_in) - f_in = torch.cat([f_in, img0, img1, warped_img0, warped_img1], dim=1) - - out = self.convblock(f_in) - ft_ = out[:, 4:, ...] - flow0 = flow0_in + out[:, :2, ...] - flow1 = flow1_in + out[:, 2:4, ...] - return flow0, flow1, ft_ - - -class NewMultiFlowDecoder(nn.Module): - def __init__(self, in_ch, skip_ch, num_flows=3): - super(NewMultiFlowDecoder, self).__init__() - norm_layer = get_bn() - - self.upsample = nn.Sequential( - nn.PixelShuffle(2), - nn.PixelShuffle(2), - convrelu(in_ch // (4 * 4), in_ch // 4, 5, 1, 2), - convrelu(in_ch // 4, in_ch // 4), - convrelu(in_ch // 4, in_ch // 4), - convrelu(in_ch // 4, in_ch // 4), - convrelu(in_ch // 4, in_ch // 2), - nn.Conv2d(in_ch // 2, in_ch // 2, kernel_size=1), - norm_layer(in_ch // 2), - nn.ReLU(inplace=True), - ) - - self.num_flows = num_flows - ch_factor = 2 - self.convblock = nn.Sequential( - convrelu(in_ch * ch_factor + 17, in_ch * ch_factor), - ResBlock(in_ch * ch_factor, skip_ch), - ResBlock(in_ch * ch_factor, skip_ch), - ResBlock(in_ch * ch_factor, skip_ch), - nn.Conv2d(in_ch * ch_factor, 8 * num_flows, kernel_size=3, padding=1), - ) - - def forward(self, ft_, f0, f1, flow0, flow1, mask=None, img0=None, img1=None): - f0 = self.upsample(f0) - # print([f1.shape,f0.shape]) - f1 = self.upsample(f1) - n = self.num_flows - flow0 = 4.0 * resize(flow0, scale_factor=4.0) - flow1 = 4.0 * resize(flow1, scale_factor=4.0) - - ft_ = resize(ft_, scale_factor=4.0) - mask = resize(mask, scale_factor=4.0) - f0_warp = warp(f0, flow0) - f1_warp = warp(f1, flow1) - - f_in = torch.cat([ft_, f0_warp, f1_warp, flow0, flow1], 1) - - assert mask is not None - f_in = torch.cat([f_in, mask], 1) - - assert img0 is not None - assert img1 is not None - warped_img0 = warp(img0, flow0) - warped_img1 = warp(img1, flow1) - f_in = torch.cat([f_in, img0, img1, warped_img0, warped_img1], dim=1) - - out = self.convblock(f_in) - delta_flow0, delta_flow1, delta_mask, img_res = torch.split( - out, [2 * n, 2 * n, n, 3 * n], 1 - ) - mask = delta_mask + mask.repeat(1, self.num_flows, 1, 1) - mask = torch.sigmoid(mask) - flow0 = delta_flow0 + flow0.repeat(1, self.num_flows, 1, 1) - flow1 = delta_flow1 + flow1.repeat(1, self.num_flows, 1, 1) - - return flow0, flow1, mask, img_res - - -def multi_flow_combine( - comb_block, img0, img1, flow0, flow1, mask=None, img_res=None, mean=None -): - assert mean is None - b, c, h, w = flow0.shape - num_flows = c // 2 - flow0 = flow0.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w) - flow1 = flow1.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w) - - mask = ( - mask.reshape(b, num_flows, 1, h, w).reshape(-1, 1, h, w) - if mask is not None - else None - ) - img_res = ( - img_res.reshape(b, num_flows, 3, h, w).reshape(-1, 3, h, w) - if img_res is not None - else 0 - ) - img0 = torch.stack([img0] * num_flows, 1).reshape(-1, 3, h, w) - img1 = torch.stack([img1] * num_flows, 1).reshape(-1, 3, h, 
w) - mean = ( - torch.stack([mean] * num_flows, 1).reshape(-1, 1, 1, 1) - if mean is not None - else 0 - ) - - img0_warp = warp(img0, flow0) - img1_warp = warp(img1, flow1) - img_warps = mask * img0_warp + (1 - mask) * img1_warp + mean + img_res - img_warps = img_warps.reshape(b, num_flows, 3, h, w) - - res = comb_block(img_warps.view(b, -1, h, w)) - imgt_pred = img_warps.mean(1) + res - - imgt_pred = (imgt_pred + 1.0) / 2 - - return imgt_pred diff --git a/backend/src/pytorch/InterpolateArchs/GIMM/raftarch.py b/backend/src/pytorch/InterpolateArchs/GIMM/raftarch.py deleted file mode 100644 index dc1e692fe..000000000 --- a/backend/src/pytorch/InterpolateArchs/GIMM/raftarch.py +++ /dev/null @@ -1,745 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class FlowHead(nn.Module): - def __init__(self, input_dim=128, hidden_dim=256): - super(FlowHead, self).__init__() - self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) - self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) - self.relu = nn.ReLU(inplace=True) - - def forward(self, x): - return self.conv2(self.relu(self.conv1(x))) - - -class ConvGRU(nn.Module): - def __init__(self, hidden_dim=128, input_dim=192 + 128): - super(ConvGRU, self).__init__() - self.convz = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) - self.convr = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) - self.convq = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) - - def forward(self, h, x): - hx = torch.cat([h, x], dim=1) - - z = torch.sigmoid(self.convz(hx)) - r = torch.sigmoid(self.convr(hx)) - q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1))) - - h = (1 - z) * h + z * q - return h - - -class SepConvGRU(nn.Module): - def __init__(self, hidden_dim=128, input_dim=192 + 128): - super(SepConvGRU, self).__init__() - self.convz1 = nn.Conv2d( - hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2) - ) - self.convr1 = nn.Conv2d( - hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2) - ) - self.convq1 = nn.Conv2d( - hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2) - ) - - self.convz2 = nn.Conv2d( - hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0) - ) - self.convr2 = nn.Conv2d( - hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0) - ) - self.convq2 = nn.Conv2d( - hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0) - ) - - def forward(self, h, x): - # horizontal - hx = torch.cat([h, x], dim=1) - z = torch.sigmoid(self.convz1(hx)) - r = torch.sigmoid(self.convr1(hx)) - q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1))) - h = (1 - z) * h + z * q - - # vertical - hx = torch.cat([h, x], dim=1) - z = torch.sigmoid(self.convz2(hx)) - r = torch.sigmoid(self.convr2(hx)) - q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1))) - h = (1 - z) * h + z * q - - return h - - -class SmallMotionEncoder(nn.Module): - def __init__(self): - super(SmallMotionEncoder, self).__init__() - cor_planes = 4 * (2 * 4 + 1) ** 2 - self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0) - self.convf1 = nn.Conv2d(2, 64, 7, padding=3) - self.convf2 = nn.Conv2d(64, 32, 3, padding=1) - self.conv = nn.Conv2d(128, 80, 3, padding=1) - - def forward(self, flow, corr): - cor = F.relu(self.convc1(corr)) - flo = F.relu(self.convf1(flow)) - flo = F.relu(self.convf2(flo)) - cor_flo = torch.cat([cor, flo], dim=1) - out = F.relu(self.conv(cor_flo)) - return torch.cat([out, flow], dim=1) - - -class BasicMotionEncoder(nn.Module): - def __init__(self): - super(BasicMotionEncoder, self).__init__() 
- cor_planes = 4 * (2 * 4 + 1) ** 2 - self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) - self.convc2 = nn.Conv2d(256, 192, 3, padding=1) - self.convf1 = nn.Conv2d(2, 128, 7, padding=3) - self.convf2 = nn.Conv2d(128, 64, 3, padding=1) - self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1) - - def forward(self, flow, corr): - cor = F.relu(self.convc1(corr)) - cor = F.relu(self.convc2(cor)) - flo = F.relu(self.convf1(flow)) - flo = F.relu(self.convf2(flo)) - - cor_flo = torch.cat([cor, flo], dim=1) - out = F.relu(self.conv(cor_flo)) - return torch.cat([out, flow], dim=1) - - -class SmallUpdateBlock(nn.Module): - def __init__(self, args, hidden_dim=96): - super(SmallUpdateBlock, self).__init__() - self.encoder = SmallMotionEncoder(args) - self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82 + 64) - self.flow_head = FlowHead(hidden_dim, hidden_dim=128) - - def forward(self, net, inp, corr, flow): - motion_features = self.encoder(flow, corr) - inp = torch.cat([inp, motion_features], dim=1) - net = self.gru(net, inp) - delta_flow = self.flow_head(net) - - return net, None, delta_flow - - -class BasicUpdateBlock(nn.Module): - def __init__(self, hidden_dim=128, input_dim=128): - super(BasicUpdateBlock, self).__init__() - self.encoder = BasicMotionEncoder() - self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128 + hidden_dim) - self.flow_head = FlowHead(hidden_dim, hidden_dim=256) - - self.mask = nn.Sequential( - nn.Conv2d(128, 256, 3, padding=1), - nn.ReLU(inplace=True), - nn.Conv2d(256, 64 * 9, 1, padding=0), - ) - - def forward(self, net, inp, corr, flow, upsample=True): - motion_features = self.encoder(flow, corr) - inp = torch.cat([inp, motion_features], dim=1) - - net = self.gru(net, inp) - delta_flow = self.flow_head(net) - - # scale mask to balence gradients - mask = 0.25 * self.mask(net) - return net, mask, delta_flow - - -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class ResidualBlock(nn.Module): - def __init__(self, in_planes, planes, norm_fn="group", stride=1): - super(ResidualBlock, self).__init__() - - self.conv1 = nn.Conv2d( - in_planes, planes, kernel_size=3, padding=1, stride=stride - ) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) - self.relu = nn.ReLU(inplace=True) - - num_groups = planes // 8 - - if norm_fn == "group": - self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) - self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) - if not stride == 1: - self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) - - elif norm_fn == "batch": - self.norm1 = nn.BatchNorm2d(planes) - self.norm2 = nn.BatchNorm2d(planes) - if not stride == 1: - self.norm3 = nn.BatchNorm2d(planes) - - elif norm_fn == "instance": - self.norm1 = nn.InstanceNorm2d(planes) - self.norm2 = nn.InstanceNorm2d(planes) - if not stride == 1: - self.norm3 = nn.InstanceNorm2d(planes) - - elif norm_fn == "none": - self.norm1 = nn.Sequential() - self.norm2 = nn.Sequential() - if not stride == 1: - self.norm3 = nn.Sequential() - - if stride == 1: - self.downsample = None - - else: - self.downsample = nn.Sequential( - nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3 - ) - - def forward(self, x): - y = x - y = self.relu(self.norm1(self.conv1(y))) - y = self.relu(self.norm2(self.conv2(y))) - - if self.downsample is not None: - x = self.downsample(x) - - return self.relu(x + y) - - -class BottleneckBlock(nn.Module): - def __init__(self, in_planes, planes, norm_fn="group", stride=1): - 
super(BottleneckBlock, self).__init__() - - self.conv1 = nn.Conv2d(in_planes, planes // 4, kernel_size=1, padding=0) - self.conv2 = nn.Conv2d( - planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride - ) - self.conv3 = nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0) - self.relu = nn.ReLU(inplace=True) - - num_groups = planes // 8 - - if norm_fn == "group": - self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4) - self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4) - self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) - if not stride == 1: - self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) - - elif norm_fn == "batch": - self.norm1 = nn.BatchNorm2d(planes // 4) - self.norm2 = nn.BatchNorm2d(planes // 4) - self.norm3 = nn.BatchNorm2d(planes) - if not stride == 1: - self.norm4 = nn.BatchNorm2d(planes) - - elif norm_fn == "instance": - self.norm1 = nn.InstanceNorm2d(planes // 4) - self.norm2 = nn.InstanceNorm2d(planes // 4) - self.norm3 = nn.InstanceNorm2d(planes) - if not stride == 1: - self.norm4 = nn.InstanceNorm2d(planes) - - elif norm_fn == "none": - self.norm1 = nn.Sequential() - self.norm2 = nn.Sequential() - self.norm3 = nn.Sequential() - if not stride == 1: - self.norm4 = nn.Sequential() - - if stride == 1: - self.downsample = None - - else: - self.downsample = nn.Sequential( - nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4 - ) - - def forward(self, x): - y = x - y = self.relu(self.norm1(self.conv1(y))) - y = self.relu(self.norm2(self.conv2(y))) - y = self.relu(self.norm3(self.conv3(y))) - - if self.downsample is not None: - x = self.downsample(x) - - return self.relu(x + y) - - -class BasicEncoder(nn.Module): - def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0, only_feat=False): - super(BasicEncoder, self).__init__() - self.norm_fn = norm_fn - self.only_feat = only_feat - - if self.norm_fn == "group": - self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) - - elif self.norm_fn == "batch": - self.norm1 = nn.BatchNorm2d(64) - - elif self.norm_fn == "instance": - self.norm1 = nn.InstanceNorm2d(64) - - elif self.norm_fn == "none": - self.norm1 = nn.Sequential() - - self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) - self.relu1 = nn.ReLU(inplace=True) - - self.in_planes = 64 - self.layer1 = self._make_layer(64, stride=1) - self.layer2 = self._make_layer(96, stride=2) - self.layer3 = self._make_layer(128, stride=2) - - if not self.only_feat: - # output convolution - self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) - - self.dropout = None - if dropout > 0: - self.dropout = nn.Dropout2d(p=dropout) - - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") - elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): - if m.weight is not None: - nn.init.constant_(m.weight, 1) - if m.bias is not None: - nn.init.constant_(m.bias, 0) - - def _make_layer(self, dim, stride=1): - layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) - layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) - layers = (layer1, layer2) - - self.in_planes = dim - return nn.Sequential(*layers) - - def forward(self, x, return_feature=False, mif=False): - features = [] - - # if input is list, combine batch dimension - is_list = isinstance(x, tuple) or isinstance(x, list) - if is_list: - batch_dim = x[0].shape[0] - x = torch.cat(x, dim=0) - - x_2 = 
F.interpolate(x, scale_factor=1 / 2, mode="bilinear", align_corners=False) - x_4 = F.interpolate(x, scale_factor=1 / 4, mode="bilinear", align_corners=False) - - def f1(feat): - feat = self.conv1(feat) - feat = self.norm1(feat) - feat = self.relu1(feat) - feat = self.layer1(feat) - return feat - - x = f1(x) - features.append(x) - x = self.layer2(x) - if mif: - x_2_2 = f1(x_2) - features.append(torch.cat([x, x_2_2], dim=1)) - else: - features.append(x) - x = self.layer3(x) - if mif: - x_2_4 = self.layer2(x_2_2) - x_4_4 = f1(x_4) - features.append(torch.cat([x, x_2_4, x_4_4], dim=1)) - else: - features.append(x) - - if not self.only_feat: - x = self.conv2(x) - - if self.training and self.dropout is not None: - x = self.dropout(x) - - if is_list: - x = torch.split(x, [batch_dim, batch_dim], dim=0) - features = [torch.split(f, [batch_dim, batch_dim], dim=0) for f in features] - if return_feature: - return x, features - else: - return x - - -class SmallEncoder(nn.Module): - def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0): - super(SmallEncoder, self).__init__() - self.norm_fn = norm_fn - - if self.norm_fn == "group": - self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) - - elif self.norm_fn == "batch": - self.norm1 = nn.BatchNorm2d(32) - - elif self.norm_fn == "instance": - self.norm1 = nn.InstanceNorm2d(32) - - elif self.norm_fn == "none": - self.norm1 = nn.Sequential() - - self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) - self.relu1 = nn.ReLU(inplace=True) - - self.in_planes = 32 - self.layer1 = self._make_layer(32, stride=1) - self.layer2 = self._make_layer(64, stride=2) - self.layer3 = self._make_layer(96, stride=2) - - self.dropout = None - if dropout > 0: - self.dropout = nn.Dropout2d(p=dropout) - - self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) - - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") - elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): - if m.weight is not None: - nn.init.constant_(m.weight, 1) - if m.bias is not None: - nn.init.constant_(m.bias, 0) - - def _make_layer(self, dim, stride=1): - layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) - layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) - layers = (layer1, layer2) - - self.in_planes = dim - return nn.Sequential(*layers) - - def forward(self, x): - # if input is list, combine batch dimension - is_list = isinstance(x, tuple) or isinstance(x, list) - if is_list: - batch_dim = x[0].shape[0] - x = torch.cat(x, dim=0) - - x = self.conv1(x) - x = self.norm1(x) - x = self.relu1(x) - - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - x = self.conv2(x) - - if self.training and self.dropout is not None: - x = self.dropout(x) - - if is_list: - x = torch.split(x, [batch_dim, batch_dim], dim=0) - - return x - - -def bilinear_sampler(img, coords, mode="bilinear", mask=False): - orig_dtype = img.dtype - img = img.float() - coords = coords.float() - """Wrapper for grid_sample, uses pixel coordinates""" - H, W = img.shape[-2:] - xgrid, ygrid = coords.split([1, 1], dim=-1) - xgrid = 2 * xgrid / (W - 1) - 1 - ygrid = 2 * ygrid / (H - 1) - 1 - - grid = torch.cat([xgrid, ygrid], dim=-1) - img = F.grid_sample(img, grid, align_corners=True) - - if mask: - mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) - return img, mask - - return img.to(dtype=orig_dtype) - - -class CorrBlock: - def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 
- self.num_levels = num_levels - self.radius = radius - self.corr_pyramid = [] - - # all pairs correlation - corr = CorrBlock.corr(fmap1, fmap2) - - batch, h1, w1, dim, h2, w2 = corr.shape - corr = corr.reshape(batch * h1 * w1, dim, h2, w2) - - self.corr_pyramid.append(corr) - for i in range(self.num_levels - 1): - corr = F.avg_pool2d(corr, 2, stride=2) - self.corr_pyramid.append(corr) - - def __call__(self, coords): - r = self.radius - coords = coords.permute(0, 2, 3, 1) - batch, h1, w1, _ = coords.shape - - out_pyramid = [] - for i in range(self.num_levels): - corr = self.corr_pyramid[i] - dx = torch.linspace(-r, r, 2 * r + 1, device=coords.device) - dy = torch.linspace(-r, r, 2 * r + 1, device=coords.device) - delta = torch.stack(torch.meshgrid(dy, dx), axis=-1) - - centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2**i - delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) - coords_lvl = centroid_lvl + delta_lvl - - corr = bilinear_sampler(corr, coords_lvl) - corr = corr.view(batch, h1, w1, -1) - out_pyramid.append(corr) - - out = torch.cat(out_pyramid, dim=-1) - return out.permute(0, 3, 1, 2).contiguous() - - @staticmethod - def corr(fmap1, fmap2): - batch, dim, ht, wd = fmap1.shape - fmap1 = fmap1.view(batch, dim, ht * wd) - fmap2 = fmap2.view(batch, dim, ht * wd) - - corr = torch.matmul(fmap1.transpose(1, 2), fmap2) - corr = corr.view(batch, ht, wd, 1, ht, wd) - return corr / torch.sqrt(torch.tensor(dim)) - - -def coords_grid(batch, ht, wd, device): - coords = torch.meshgrid( - torch.arange(ht, device=device), torch.arange(wd, device=device) - ) - coords = torch.stack(coords[::-1], dim=0) - return coords[None].repeat(batch, 1, 1, 1) - - -def upflow8(flow, mode="bilinear"): - new_size = (8 * flow.shape[2], 8 * flow.shape[3]) - return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) - - -try: - autocast = torch.cuda.amp.autocast -except: - # dummy autocast for PyTorch < 1.6 - class autocast: - def __init__(self, enabled): - pass - - def __enter__(self): - pass - - def __exit__(self, *args): - pass - - -class RAFT(nn.Module): - def __init__(self): - super(RAFT, self).__init__() - - self.hidden_dim = hdim = 128 - self.context_dim = cdim = 128 - self.corr_levels = 4 - self.corr_radius = 4 - self.corr_levels = 4 - self.corr_radius = 4 - - self.dropout = 0 - - self.alternate_corr = False - - # feature network, context network, and update block - self.fnet = BasicEncoder( - output_dim=256, norm_fn="instance", dropout=self.dropout - ) - self.cnet = BasicEncoder( - output_dim=hdim + cdim, norm_fn="batch", dropout=self.dropout - ) - self.update_block = BasicUpdateBlock(hidden_dim=hdim) - - def freeze_bn(self): - for m in self.modules(): - if isinstance(m, nn.BatchNorm2d): - m.eval() - - def initialize_flow(self, img): - """Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" - N, C, H, W = img.shape - coords0 = coords_grid(N, H // 8, W // 8, device=img.device) - coords1 = coords_grid(N, H // 8, W // 8, device=img.device) - - # optical flow computed as difference: flow = coords1 - coords0 - return coords0, coords1 - - def upsample_flow(self, flow, mask): - """Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination""" - N, _, H, W = flow.shape - mask = mask.view(N, 1, 9, 8, 8, H, W) - mask = torch.softmax(mask, dim=2) - - up_flow = F.unfold(8 * flow, [3, 3], padding=1) - up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) - - up_flow = torch.sum(mask * up_flow, dim=2) - up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) - return 
up_flow.reshape(N, 2, 8 * H, 8 * W) - - def forward( - self, - image1, - image2, - iters=12, - flow_init=None, - upsample=True, - test_mode=False, - return_feat=True, - ): - """Estimate optical flow between pair of frames""" - imgdtype = image1.dtype - - image1 = 2 * (image1 / 255.0) - 1.0 - image2 = 2 * (image2 / 255.0) - 1.0 - - image1 = image1.contiguous() - image2 = image2.contiguous() - - hdim = self.hidden_dim - cdim = self.context_dim - - # run the feature network - with autocast(enabled=False): - fmap1, fmap2 = self.fnet([image1, image2]) - - corr_fn = CorrBlock(fmap1, fmap2, radius=4) - - # run the context network - with autocast(enabled=False): # no mixed precision - cnet, feats = self.cnet(image1, return_feature=True) - net, inp = torch.split(cnet, [hdim, cdim], dim=1) - net = torch.tanh(net) - inp = torch.relu(inp) - - coords0, coords1 = self.initialize_flow(image1) - - if flow_init is not None: - coords1 = coords1 + flow_init - - flow_predictions = [] - for itr in range(iters): - coords1 = coords1.detach() - corr = corr_fn(coords1) # index correlation volume - - flow = coords1 - coords0 - with autocast(enabled=False): - net, up_mask, delta_flow = self.update_block( - net.to(dtype=imgdtype), - inp.to(dtype=imgdtype), - corr.to(dtype=imgdtype), - flow.to(dtype=imgdtype), - ) - - # F(t+1) = F(t) + \Delta(t) - coords1 = coords1 + delta_flow - - # upsample predictions - if up_mask is None: - flow_up = upflow8(coords1 - coords0) - else: - flow_up = self.upsample_flow(coords1 - coords0, up_mask) - - flow_predictions.append(flow_up) - - if test_mode: - return coords1 - coords0, flow_up - - if return_feat: - return flow_up, feats[1:], fmap1 - - return flow_predictions - - -class BidirCorrBlock: - def __init__(self, fmap1, fmap2, num_levels=4, radius=4): - self.num_levels = num_levels - self.radius = radius - self.corr_pyramid = [] - self.corr_pyramid_T = [] - - corr = BidirCorrBlock.corr(fmap1, fmap2) - batch, h1, w1, dim, h2, w2 = corr.shape - corr_T = corr.clone().permute(0, 4, 5, 3, 1, 2) - - corr = corr.reshape(batch * h1 * w1, dim, h2, w2) - corr_T = corr_T.reshape(batch * h2 * w2, dim, h1, w1) - - self.corr_pyramid.append(corr) - self.corr_pyramid_T.append(corr_T) - - for _ in range(self.num_levels - 1): - corr = F.avg_pool2d(corr, 2, stride=2) - corr_T = F.avg_pool2d(corr_T, 2, stride=2) - self.corr_pyramid.append(corr) - self.corr_pyramid_T.append(corr_T) - - def __call__(self, coords0, coords1): - r = self.radius - coords0 = coords0.permute(0, 2, 3, 1) - coords1 = coords1.permute(0, 2, 3, 1) - assert coords0.shape == coords1.shape, ( - f"coords0 shape: [{coords0.shape}] is not equal to [{coords1.shape}]" - ) - batch, h1, w1, _ = coords0.shape - - out_pyramid = [] - out_pyramid_T = [] - for i in range(self.num_levels): - corr = self.corr_pyramid[i] - corr_T = self.corr_pyramid_T[i] - - dx = torch.linspace(-r, r, 2 * r + 1, device=coords0.device) - dy = torch.linspace(-r, r, 2 * r + 1, device=coords0.device) - delta = torch.stack(torch.meshgrid(dy, dx), axis=-1) - delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) - - centroid_lvl_0 = coords0.reshape(batch * h1 * w1, 1, 1, 2) / 2**i - centroid_lvl_1 = coords1.reshape(batch * h1 * w1, 1, 1, 2) / 2**i - coords_lvl_0 = centroid_lvl_0 + delta_lvl - coords_lvl_1 = centroid_lvl_1 + delta_lvl - - corr = bilinear_sampler(corr, coords_lvl_0) - corr_T = bilinear_sampler(corr_T, coords_lvl_1) - corr = corr.view(batch, h1, w1, -1) - corr_T = corr_T.view(batch, h1, w1, -1) - out_pyramid.append(corr) - out_pyramid_T.append(corr_T) - - out = 
torch.cat(out_pyramid, dim=-1) - out_T = torch.cat(out_pyramid_T, dim=-1) - return ( - out.permute(0, 3, 1, 2).contiguous(), - out_T.permute(0, 3, 1, 2).contiguous(), - ) - - @staticmethod - def corr(fmap1, fmap2): - batch, dim, ht, wd = fmap1.shape - fmap1 = fmap1.view(batch, dim, ht * wd) - fmap2 = fmap2.view(batch, dim, ht * wd) - - corr = torch.matmul(fmap1.transpose(1, 2), fmap2) - corr = corr.view(batch, ht, wd, 1, ht, wd) - return corr / torch.sqrt(torch.tensor(dim)) diff --git a/backend/src/pytorch/InterpolateArchs/GMFSS/GMFSS.py b/backend/src/pytorch/InterpolateArchs/GMFSS/GMFSS.py index ecbf74e2e..e78b8a6fe 100644 --- a/backend/src/pytorch/InterpolateArchs/GMFSS/GMFSS.py +++ b/backend/src/pytorch/InterpolateArchs/GMFSS/GMFSS.py @@ -7,14 +7,9 @@ from .gmflow.gmflow import GMFlow from .MetricNet import MetricNet from .FusionNet_u import GridNet -from ....constants import HAS_SYSTEM_CUDA +from ....utils.Util import CudaChecker from ..DetectInterpolateArch import ArchDetect -if HAS_SYSTEM_CUDA: - from ..util.softsplat_cupy import softsplat -else: - from ..util.softsplat_torch import softsplat - class GMFSS: def __init__( @@ -44,6 +39,10 @@ def __init__( tmp = max(_pad, int(_pad / self.scale)) self.pw = math.ceil(self.width / tmp) * tmp self.ph = math.ceil(self.height / tmp) * tmp + if CudaChecker.checkForCUDA(): + from ..util.softsplat_cupy import softsplat + else: + from ..util.softsplat_torch import softsplat self.warp = softsplat combined_state_dict = torch.load(model_path, map_location="cpu") @@ -61,7 +60,7 @@ def __init__( # model unspecific setup self.ifnet = IFNet(ensemble=ensemble).to(dtype=dtype, device=device) - self.flownet = GMFlow().to(dtype=torch.float, device=device) + self.flownet = GMFlow().to(dtype=dtype, device=device) self.metricnet = MetricNet().to(dtype=dtype, device=device) self.feat_ext = FeatureNet().to(dtype=dtype, device=device) self.fusionnet = GridNet().to(dtype=dtype, device=device) @@ -114,9 +113,9 @@ def __init__( import gc gc.collect() - torch.cuda.empty_cache() - torch.cuda.reset_max_memory_allocated() - torch.cuda.reset_max_memory_cached() + #torch.cuda.empty_cache() + #torch.cuda.reset_max_memory_allocated() + #torch.cuda.reset_max_memory_cached() self.ifnet = trtHandler.load_engine("IFNet.engine") self.feat_ext = trtHandler.load_engine("Feat.engine") self.flownet = trtHandler.load_engine("Flownet.engine") @@ -138,8 +137,8 @@ def forward(self, img0, img1, timestep, scale=None): imgf0 = img0 imgf1 = img1 if self.flow01 is None: - self.flow01 = self.flownet(imgf0.float(), imgf1.float()).to(dtype=self.dtype) - self.flow10 = self.flownet(imgf1.float(), imgf0.float()).to(dtype=self.dtype) + self.flow01 = self.flownet(imgf0, imgf1) + self.flow10 = self.flownet(imgf1, imgf0) if self.scale != 1.0: self.flow01 = ( F.interpolate( diff --git a/backend/src/pytorch/InterpolateGIMM.py b/backend/src/pytorch/InterpolateGIMM.py deleted file mode 100644 index 8faef533b..000000000 --- a/backend/src/pytorch/InterpolateGIMM.py +++ /dev/null @@ -1,189 +0,0 @@ -import torch - -from .TorchUtils import TorchUtils -# from backend.src.pytorch.InterpolateArchs.GIMM import GIMM -from .BaseInterpolate import BaseInterpolate -import math -import logging -import sys -from ..utils.Util import ( - warnAndLog, - log, -) - -from ..utils.Util import CudaChecker -from time import sleep - -torch.set_float32_matmul_precision("medium") -torch.set_grad_enabled(False) -logging.basicConfig(level=logging.INFO) - -class InterpolateGIMMTorch(BaseInterpolate): - @torch.inference_mode() - def 
__init__( - self, - modelPath: str, - ceilInterpolateFactor: int = 2, - width: int = 1920, - height: int = 1080, - device: str = "default", - dtype: str = "auto", - backend: str = "pytorch", - UHDMode: bool = False, - ensemble: bool = False, - dynamicScaledOpticalFlow: bool = False, - gpu_id: int = 0, - hdr_mode: bool = False, - *args, - **kwargs, - ): - self.interpolateModel = modelPath - self.width = width - self.height = height - _pad = 64 - self.scale = 0.5 # GIMM uses fat amounts of vram, needs really low flow resolution for regular resolutions - if UHDMode: - self.scale = 0.25 # GIMM uses fat amounts of vram, needs really low flow resolution for UHD - tmp = max(_pad, int(_pad / self.scale)) - self.pw = math.ceil(self.width / tmp) * tmp - self.ph = math.ceil(self.height / tmp) * tmp - padding = (0, self.pw - self.width, 0, self.ph - self.height) - self.torchUtils = TorchUtils( - width=width, - height=height, - hdr_mode=hdr_mode, - padding=padding, - device_type=device, - ) - self.device = self.torchUtils.handle_device(device, gpu_id=gpu_id) - self.dtype = self.torchUtils.handle_precision(dtype) - if ensemble: - print("Ensemble is not implemented for GIMM, disabling", file=sys.stderr) - if dynamicScaledOpticalFlow: - print( - "Dynamic Scaled Optical Flow is not implemented for GIMM, disabling", - file=sys.stderr, - ) - - self.backend = backend - self.ceilInterpolateFactor = ceilInterpolateFactor - self.hdr_mode = hdr_mode # used in base interpolate class (ik inheritance is bad leave me alone) - self.frame0 = None - - self.doEncodingOnFrame = False - self._load() - - @torch.inference_mode() - def _load(self): - self.stream = self.torchUtils.init_stream() - self.prepareStream = self.torchUtils.init_stream() - self.copyStream = self.torchUtils.init_stream() - with self.torchUtils.run_stream(self.prepareStream): # type: ignore - from .InterpolateArchs.GIMM.gimmvfi_r import GIMMVFI_R - - self.flownet = GIMMVFI_R( - model_path=self.interpolateModel, width=self.width, height=self.height - ) - state_dict = torch.load(self.interpolateModel, map_location=self.device)[ - "gimmvfi_r" - ] - self.flownet.load_state_dict(state_dict) - self.flownet.eval().to(device=self.device, dtype=self.dtype) - - - - dummyInput = torch.zeros( - [1, 3, self.ph, self.pw], dtype=self.dtype, device=self.device - ) - dummyInput2 = torch.zeros( - [1, 3, self.ph, self.pw], dtype=self.dtype, device=self.device - ) - xs = torch.cat( - (dummyInput.unsqueeze(2), dummyInput2.unsqueeze(2)), dim=2 - ).to(self.device, non_blocking=True) - s_shape = xs.shape[-2:] - - # caching the timestep tensor in a dict with the timestep as a float for the key - - self.timestepDict = {} - self.coordDict = {} - - for n in range(self.ceilInterpolateFactor): - timestep = n / (self.ceilInterpolateFactor) - timestep_tens = ( - n - * 1 - / self.ceilInterpolateFactor - * torch.ones(xs.shape[0]) - .to(xs.device) - .to(self.dtype) - .reshape(-1, 1, 1, 1) - ) - self.timestepDict[timestep] = timestep_tens - coord = ( - self.flownet.sample_coord_input( - 1, - s_shape, - [1 / self.ceilInterpolateFactor * n], - device=self.device, - upsample_ratio=self.scale, - ).to(non_blocking=True, dtype=self.dtype, device=self.device), - None, - ) - self.coordDict[timestep] = coord - - log("GIMM loaded") - log("Scale: " + str(self.scale)) - HAS_SYSTEM_CUDA = CudaChecker().HAS_SYSTEM_CUDA - log("Using System CUDA: " + str(HAS_SYSTEM_CUDA)) - if not HAS_SYSTEM_CUDA: - print( - "WARNING: System CUDA not found, falling back to PyTorch softsplat. 
This will be a bit slower.", - file=sys.stderr, - ) - if self.backend == "tensorrt": - warnAndLog( - "TensorRT is not implemented for GIMM yet, falling back to PyTorch" - ) - self.torchUtils.sync_stream(self.prepareStream) # type: ignore - - @torch.inference_mode() - def __call__( - self, - img1, - transition=False, - ): # type: ignore - with self.torchUtils.run_stream(self.stream): # type: ignore - if self.frame0 is None: - self.frame0 = self.torchUtils.frame_to_tensor(img1, self.prepareStream, self.device, self.dtype) - return - frame1 = self.torchUtils.frame_to_tensor(img1, self.prepareStream, self.device, self.dtype) - for n in range(self.ceilInterpolateFactor - 1): - if not transition: - timestep = (n + 1) * 1.0 / (self.ceilInterpolateFactor) - coord = self.coordDict[timestep] - timestep_tens = self.timestepDict[timestep] - xs = torch.cat( - (self.frame0.unsqueeze(2), frame1.unsqueeze(2)), dim=2 - ).to(self.device, non_blocking=True, dtype=self.dtype) - - while self.flownet is None: - sleep(1) - with torch.autocast(enabled=True, device_type=self.device.type): - output = self.flownet( - xs, coord, timestep_tens, ds_factor=self.scale - ) - - if torch.isnan(output).any(): - # if there are nans in output, reload with float32 precision and process.... dumb fix but whatever - raise ValueError("Nans in output") - - output = self.torchUtils.tensor_to_frame(output) - yield output - - else: - yield img1 - - self.torchUtils.copy_tensor(self.frame0, frame1, self.copyStream) - - self.torchUtils.sync_all_streams() diff --git a/backend/src/pytorch/InterpolateGMFSS.py b/backend/src/pytorch/InterpolateGMFSS.py index 9c8414a99..76c21750c 100644 --- a/backend/src/pytorch/InterpolateGMFSS.py +++ b/backend/src/pytorch/InterpolateGMFSS.py @@ -49,6 +49,8 @@ def __init__( self.hdr_mode = hdr_mode # used in base interpolate class (ik inheritance is bad leave me alone) self.dynamicScaledOpticalFlow = dynamicScaledOpticalFlow self.UHDMode = UHDMode + self.gpu_id = gpu_id + self.CompareNet = None self.max_timestep = max_timestep if UHDMode: @@ -74,8 +76,8 @@ def __init__( @torch.inference_mode() def _load(self): - self.stream = self.torchUtils.init_stream() - self.prepareStream = self.torchUtils.init_stream() + self.stream = self.torchUtils.init_stream(gpu_id=self.gpu_id) + self.prepareStream = self.torchUtils.init_stream(gpu_id=self.gpu_id) with self.torchUtils.run_stream(self.prepareStream): # type: ignore if self.dynamicScaledOpticalFlow: from ..utils.SSIM import SSIM diff --git a/backend/src/pytorch/InterpolateIFRNET.py b/backend/src/pytorch/InterpolateIFRNET.py index 7158d2779..4f045fd23 100644 --- a/backend/src/pytorch/InterpolateIFRNET.py +++ b/backend/src/pytorch/InterpolateIFRNET.py @@ -44,6 +44,7 @@ def __init__( self.ensemble = ensemble self.hdr_mode = hdr_mode # used in base interpolate class (ik inheritance is bad leave me alone) self.UHDMode = UHDMode + self.gpu_id = gpu_id self.CompareNet = None self.max_timestep = max_timestep if UHDMode: @@ -82,8 +83,8 @@ def __init__( @torch.inference_mode() def _load(self): - self.stream = self.torchUtils.init_stream() - self.prepareStream = self.torchUtils.init_stream() + self.stream = self.torchUtils.init_stream(self.gpu_id) + self.prepareStream = self.torchUtils.init_stream(self.gpu_id) with self.torchUtils.run_stream(self.prepareStream): # type: ignore from .InterpolateArchs.IFRNET.IFRNet import IFRNet diff --git a/backend/src/pytorch/InterpolateRIFE.py b/backend/src/pytorch/InterpolateRIFE.py index 90088ca13..890f15618 100644 --- 
a/backend/src/pytorch/InterpolateRIFE.py +++ b/backend/src/pytorch/InterpolateRIFE.py @@ -392,11 +392,11 @@ def _load(self): flownet_engine = trtHandler.build_engine(self.flownet, self.dtype, self.device, flownet_inputs, trt_engine_name=base_trt_engine_name, trt_multi_precision_engine=True, dynamic_shapes=flownet_dynamic_shapes,) trtHandler.save_engine(flownet_engine, base_trt_engine_name, flownet_inputs) - torch.cuda.empty_cache() + TorchUtils.clear_cache() if self.encode: encode_engine = trtHandler.build_engine(self.encode, self.dtype, self.device, encode_inputs, trt_engine_name=encode_trt_engine_name, trt_multi_precision_engine=True, dynamic_shapes=encode_dynamic_shapes,) trtHandler.save_engine(encode_engine, encode_trt_engine_name, encode_inputs) - torch.cuda.empty_cache() + TorchUtils.clear_cache() self.flownet = trtHandler.load_engine(base_trt_engine_name) if self.encode: diff --git a/backend/src/pytorch/InterpolateTorch.py b/backend/src/pytorch/InterpolateTorch.py index 123a533ba..70eab51e8 100644 --- a/backend/src/pytorch/InterpolateTorch.py +++ b/backend/src/pytorch/InterpolateTorch.py @@ -7,7 +7,6 @@ # from backend.src.pytorch.InterpolateArchs.GIMM import GIMM from .InterpolateArchs.DetectInterpolateArch import ArchDetect -from .InterpolateGIMM import InterpolateGIMMTorch from .InterpolateGMFSS import InterpolateGMFSSTorch from .InterpolateRIFE import InterpolateRifeTorch, InterpolateRIFEDRBA from .InterpolateIFRNET import InterpolateIFRNetTorch @@ -25,7 +24,5 @@ def build_interpolation_method(interpolate_model_path, backend, drba=False): return InterpolateRifeTorch case "gmfss": return InterpolateGMFSSTorch - case "gimm": - return InterpolateGIMMTorch case "ifrnet": return InterpolateIFRNetTorch # IFRNet is a RIFE based architecture diff --git a/backend/src/pytorch/TensorRTHandler.py b/backend/src/pytorch/TensorRTHandler.py index 3fcc759f3..5498f97d1 100644 --- a/backend/src/pytorch/TensorRTHandler.py +++ b/backend/src/pytorch/TensorRTHandler.py @@ -33,6 +33,7 @@ import tensorrt as trt from torch._export.converter import TS2EPConverter from torch.export.exported_program import ExportedProgram + from .TorchUtils import TorchUtils def torchscript_to_dynamo( model: torch.nn.Module, example_inputs: list[torch.Tensor] @@ -43,7 +44,7 @@ def torchscript_to_dynamo( module, sample_args=tuple(example_inputs), sample_kwargs=None ).convert() del module - torch.cuda.empty_cache() + TorchUtils.clear_cache() return exported_program def nnmodule_to_dynamo( @@ -145,7 +146,7 @@ def build_engine( """ start_time = time.time() trt_engine_name += self.trt_path_appendix - torch.cuda.empty_cache() + TorchUtils.clear_cache() """Builds a TensorRT engine from the provided model.""" print( f"Building TensorRT engine {os.path.basename(trt_engine_name)}. 
This may take a while...", @@ -155,7 +156,7 @@ def build_engine( with suppress_stdout_stderr(): exported_program = nnmodule_to_dynamo(model, example_inputs, dynamic_shapes=dynamic_shapes) - torch.cuda.empty_cache() + TorchUtils.clear_cache() exported_program = self.grid_sample_decomp(exported_program) @@ -178,7 +179,7 @@ def build_engine( f"TensorRT engine built in {time.time() - start_time:.2f} seconds.", file=sys.stderr, ) - torch.cuda.empty_cache() + TorchUtils.clear_cache() return model_trt def save_engine(self, trt_engine: torch.jit.ScriptModule, trt_engine_name: str, example_inputs: list[torch.Tensor]): @@ -192,7 +193,7 @@ def save_engine(self, trt_engine: torch.jit.ScriptModule, trt_engine_name: str, output_format="torchscript", inputs=tuple(example_inputs), ) - torch.cuda.empty_cache() + TorchUtils.clear_cache() def load_engine(self, trt_engine_name: str) -> torch.jit.ScriptModule: """Loads a TensorRT engine from the specified path.""" diff --git a/backend/src/pytorch/TorchUtils.py b/backend/src/pytorch/TorchUtils.py index eac700fc0..292e91d4e 100644 --- a/backend/src/pytorch/TorchUtils.py +++ b/backend/src/pytorch/TorchUtils.py @@ -7,7 +7,7 @@ backendDetect = BackendDetect() from ..utils.Util import ( - warnAndLog, + log, CudaChecker ) HAS_PYTORCH_CUDA = CudaChecker().HAS_PYTORCH_CUDA @@ -45,11 +45,12 @@ def __exit__(self, exc_type, exc_val, exc_tb): class TorchUtils: # device and precision are in string formats, loaded straight from the command line arguments - def __init__(self, width, height, device_type:str, hdr_mode=False, padding=None, ): + def __init__(self, width, height, device_type:str, hdr_mode=False, padding=None, gpu_id=0): self.width = width self.height = height self.hdr_mode = hdr_mode self.padding = padding + self.gpu_id = gpu_id if device_type == "auto": self.device_type = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "xpu" if torch.xpu.is_available() else "cpu" else: @@ -59,9 +60,8 @@ def __init__(self, width, height, device_type:str, hdr_mode=False, padding=None, del test_tensor self.use_numpy = True except Exception as e: - warnAndLog(f"Failed to create a Numpy tensor, This will heavily reduce performance.") + log(f"Failed to create a Numpy tensor, This will heavily reduce performance.") self.use_numpy = False - self.__init_stream_func = self.__init_stream_function() self.__run_stream_func = self.__run_stream_function() self.__sync_all_streams_func = self.__sync_all_streams_function() @@ -74,19 +74,21 @@ def __sync_all_streams_function(self): return dummy_function # CPU does not require explicit synchronization if self.device_type == "xpu": return torch.xpu.synchronize - return lambda: warnAndLog(f"Unknown device type {self.device_type}, skipping stream synchronization.") - - def __init_stream_function(self)-> callable: + return lambda: log(f"Unknown device type {self.device_type}, skipping stream synchronization.") + + def init_stream(self, gpu_id = 0) -> torch.Stream: """ Initializes the stream based on the device type. 
""" + log(f"Initializing stream for device {self.device_type} (GPU ID: {gpu_id})") + device = self.handle_device(self.device_type, gpu_id) if self.device_type == "cuda": - return torch.cuda.Stream + return torch.cuda.Stream(device=device) elif self.device_type == "xpu": - return torch.xpu.Stream + return torch.xpu.Stream(device=device) else: - return DummyContextManager # For CPU and MPS, we can use a dummy stream - + return DummyContextManager() # For CPU and MPS, we can use a dummy stream + def __run_stream_function(self) -> callable: """ Runs the stream based on the device type. @@ -97,10 +99,8 @@ def __run_stream_function(self) -> callable: return torch.xpu.stream else: return dummy_context_manager # For CPU and MPS, we can use a dummy context manager - - def init_stream(self): - return self.__init_stream_func() - + + def run_stream(self, stream): return self.__run_stream_func(stream) @@ -113,7 +113,7 @@ def sync_stream(self, stream: torch.Stream): case "cpu": pass # CPU does not require explicit synchronization case _: - warnAndLog(f"Unknown device type {self.device_type}, skipping stream synchronization.") + log(f"Unknown device type {self.device_type}, skipping stream synchronization.") # For other devices, we assume no synchronization is needed. def sync_all_streams(self): @@ -126,7 +126,8 @@ def sync_all_streams(self): def handle_device(device, gpu_id: int = 0) -> torch.device: """ returns device based on gpu id and device parameter - """ + """ + log(f"Handling device: {device}, GPU ID: {gpu_id}") if device == "auto": if torch.cuda.is_available(): torchdevice = torch.device("cuda", gpu_id) @@ -146,6 +147,7 @@ def handle_device(device, gpu_id: int = 0) -> torch.device: @staticmethod def handle_precision(precision) -> torch.dtype: + log(f"Handling precision: {precision}") if precision == "auto": return torch.float16 if backendDetect.get_half_precision() else torch.float32 if precision == "float32": @@ -194,6 +196,8 @@ def frame_to_tensor(self, frame, stream: torch.Stream, device: torch.device, dty def clear_cache(): if HAS_PYTORCH_CUDA: torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_max_memory_cached() @torch.inference_mode() def tensor_to_frame(self, frame: torch.Tensor): diff --git a/backend/src/pytorch/UpscaleTorch.py b/backend/src/pytorch/UpscaleTorch.py index 107fee5a4..4f316bbe2 100644 --- a/backend/src/pytorch/UpscaleTorch.py +++ b/backend/src/pytorch/UpscaleTorch.py @@ -3,6 +3,7 @@ import gc from .TorchUtils import TorchUtils +from .spandrel.UpscaleModelWrapper import UpscaleModelWrapper import torch as torch import torch.nn.functional as F import sys @@ -89,9 +90,9 @@ def __init__( ): self.torchUtils = TorchUtils(width=width, height=height,hdr_mode=hdr_mode,device_type=device) - device = self.torchUtils.handle_device(device, gpu_id) - self.tile_pad = tile_pad + device = self.torchUtils.handle_device(device, gpu_id=gpu_id) self.dtype = self.torchUtils.handle_precision(precision) + self.tile_pad = tile_pad self.device = device self.videoWidth = width self.videoHeight = height @@ -99,7 +100,6 @@ def __init__( self.tile = [self.tilesize, self.tilesize] self.modelPath = modelPath self.backend = backend - self.trt_workspace_size = trt_workspace_size self.trt_optimization_level = trt_optimization_level @@ -111,17 +111,12 @@ def __init__( self.trt_static_shape = trt_static_shape # streams - self.stream = self.torchUtils.init_stream() - self.f2tstream = self.torchUtils.init_stream() - self.prepareStream = self.torchUtils.init_stream() - 
self.convertStream = self.torchUtils.init_stream() + self.stream = self.torchUtils.init_stream(gpu_id=gpu_id) + self.f2tstream = self.torchUtils.init_stream(gpu_id=gpu_id) + self.prepareStream = self.torchUtils.init_stream(gpu_id=gpu_id) + self.convertStream = self.torchUtils.init_stream(gpu_id=gpu_id) self._load() - - # add prores - # add janai v2 - # add bhi light model - @torch.inference_mode() def _load(self): @@ -140,9 +135,13 @@ def _load(self): with self.torchUtils.run_stream(self.prepareStream): - self.set_self_model(backend="pytorch") + self.upscale_model_wrapper = UpscaleModelWrapper( + model_path=self.modelPath, + device=self.device, + precision=self.dtype, + ) - match self.scale: + match self.upscale_model_wrapper.get_scale(): case 1: modulo = 4 case 2: @@ -223,24 +222,9 @@ def _load(self): dim_width = _width * modulo dynamic_shapes = {"x": {2: dim_height, 3: dim_width}} - - - # inference and get re-load state dict due to issue with span. - try: - model = self.model - model(inputs[0]) - self.model.load_state_dict(model.state_dict()) - output = model(inputs[0]) - del model - torch.cuda.empty_cache() - except Exception as e: - print("Test inf failed") - - - try: trt_engine = trtHandler.build_engine( - self.model, + self.upscale_model_wrapper.get_model(), self.dtype, self.device, example_inputs=inputs, @@ -266,7 +250,7 @@ def _load(self): else: trt_engine = trtHandler.build_engine( - self.model, + self.upscale_model_wrapper.get_model(), self.dtype, self.device, example_inputs=inputs, @@ -283,16 +267,15 @@ def _load(self): raise RuntimeError( f"Failed to build TensorRT engine: {e}\n" ) - self.set_self_model(backend="tensorrt", trt_engine_name=self.trt_engine_name) - - + model = trtHandler.load_engine(trt_engine_name=self.trt_engine_name) + self.upscale_model_wrapper.load_model(model) self.torchUtils.clear_cache() self.torchUtils.sync_all_streams() @torch.inference_mode() def hotUnload(self): - self.model = None + self.upscale_model_wrapper = None gc.collect() self.torchUtils.clear_cache() if HAS_PYTORCH_CUDA: @@ -302,67 +285,18 @@ def hotUnload(self): @torch.inference_mode() def hotReload(self): self._load() - - @torch.inference_mode() - def set_self_model(self, backend="pytorch", trt_engine_name=None): - torch.cuda.empty_cache() - if backend == "tensorrt": - from .TensorRTHandler import TorchTensorRTHandler - trtHandler = TorchTensorRTHandler(model_parent_path=os.path.dirname(self.modelPath),) - self.model = trtHandler.load_engine(trt_engine_name=trt_engine_name) - else: - self.model = self.loadModel( - modelPath=self.modelPath, device=self.device, dtype=self.dtype - ) - - @torch.inference_mode() - def loadModel( - self, modelPath: str, dtype: torch.dtype = torch.float32, device: str = "cuda" - ) -> torch.nn.Module: - try: - from .spandrel import ModelLoader, ImageModelDescriptor, UnsupportedModelError - except ImportError: - # spandrel will import like this if its a submodule - from .spandrel.libs.spandrel.spandrel import ModelLoader, ImageModelDescriptor, UnsupportedModelError - try: - model = ModelLoader().load_from_file(modelPath) - assert isinstance(model, ImageModelDescriptor) - self.model = model - # get model attributes - - except (UnsupportedModelError) as e: - from .VSRArchs.AnimeSR import AnimeSRArch - from .VSRArchs.vsr_inference_helper import VSRInferenceHelper - model = AnimeSRArch() - self.inference = VSRInferenceHelper(model) - - self.scale = model.scale - model = model.model - - model.load_state_dict(model.state_dict(), assign=True) - model.eval().to(self.device, 
dtype=self.dtype) - try: - example_input = torch.zeros((1, 3, 64, 64), device=self.device, dtype=self.dtype) - model(example_input) - except Exception as e: - print("Error occured during model validation, falling back to float32 dtype.\n") - log(str(e)) - model = model.to(self.device, dtype=torch.float32) - - return model - @torch.inference_mode() def __call__(self, image: bytes) -> torch.Tensor: image = self.torchUtils.frame_to_tensor(image, self.f2tstream, self.device, self.dtype) with self.torchUtils.run_stream(self.stream), torch.amp.autocast( enabled=self.dtype == torch.float16,device_type="cuda"): - while self.model is None: + while self.upscale_model_wrapper is None: sleep(1) if self.tilesize == 0: - - output = self.model(image) - + + output = self.upscale_model_wrapper(image) + else: output = self.renderTiledImage(image) @@ -375,7 +309,7 @@ def __call__(self, image: bytes) -> torch.Tensor: return output def getScale(self): - return self.scale + return self.upscale_model_wrapper.get_scale() @torch.inference_mode() def renderTiledImage( @@ -431,7 +365,7 @@ def renderTiledImage( ) # process tile - output_tile = self.model( + output_tile = self.upscale_model_wrapper( input_tile ) diff --git a/backend/src/pytorch/spandrel/UpscaleModelWrapper.py b/backend/src/pytorch/spandrel/UpscaleModelWrapper.py new file mode 100644 index 000000000..f496fc238 --- /dev/null +++ b/backend/src/pytorch/spandrel/UpscaleModelWrapper.py @@ -0,0 +1,72 @@ +import torch +from ...utils.Util import log +from ..TorchUtils import TorchUtils +import os + +class UpscaleModelWrapper: + def __init__(self, model_path: torch.nn.Module, device: torch.device, precision: torch.dtype): + self.__model_path = model_path + self.__device = device + self.__precision = precision + self.load_model() + self.set_precision(self.__precision) + self.__test_model_precision() + + def set_precision(self, precision: torch.dtype): + self.__precision = precision + self.__model.to(self.__device, dtype=precision) + + def get_model(self): + return self.__model + + def get_scale(self): + return self.__scale + + def load_state_dict(self, state_dict): + self.__model.load_state_dict(state_dict) + + def __test_inference(self, test_input:torch.Tensor): + # inference and get re-load state dict due to issue with span. + with torch.inference_mode(): + model = self.__model + model(test_input) + output = model(test_input) + self.__model.load_state_dict(model.state_dict()) # reload state dict to fix span + del model + TorchUtils.clear_cache() + + def __test_model_precision(self): + test_input = torch.randn(1, 3, 64, 64).to(self.__device, dtype=self.__precision) + with torch.inference_mode(): + try: + self.__test_inference(test_input) + except Exception as e: + log(f"Model precision {self.__precision} not supported, falling back to float32: {e}") + self.set_precision(torch.float32) + self.__test_inference(test_input) + + + @torch.inference_mode() + def load_model(self, model=None) -> torch.nn.Module: + if not model: + try: + from . 
import ModelLoader, ImageModelDescriptor, UnsupportedModelError + except ImportError: + # spandrel will import like this if its a submodule + from .libs.spandrel.spandrel import ModelLoader, ImageModelDescriptor, UnsupportedModelError + try: + model = ModelLoader().load_from_file(self.__model_path) + assert isinstance(model, ImageModelDescriptor) + # get model attributes + + except (UnsupportedModelError) as e: + log(f"Model at {self.__model_path} is not supported: {e}") + raise e + + self.__scale = model.scale + model = model.model + + self.__model = model + + def __call__(self, *args, **kwargs): + return self.__model(*args, **kwargs) \ No newline at end of file diff --git a/backend/src/utils/BackendDetect.py b/backend/src/utils/BackendDetect.py index 85f8a59ad..05059016c 100644 --- a/backend/src/utils/BackendDetect.py +++ b/backend/src/utils/BackendDetect.py @@ -42,7 +42,7 @@ def __init__(self): def __get_pytorch_device(self): if "cu" in self.__torch.__version__: return "cuda" - if "rocm" in self.__torch.__version__: return "cuda" + if "rocm" in self.__torch.__version__: return "rocm" if self.__torch.xpu.is_available(): return "xpu" if self.__torch.backends.mps.is_available(): return "mps" return "CPU" @@ -59,7 +59,7 @@ def get_half_precision(self): """ try: - x = self.__torch.tensor([1.0], dtype=self.__torch.float16).to(device=self.pytorch_device) + x = self.__torch.tensor([1.0], dtype=self.__torch.float16).to(device="cuda" if self.pytorch_device == "rocm" else self.pytorch_device) return True except Exception as e: log(str(e)) @@ -78,6 +78,7 @@ def get_gpus_torch(self): torch_cmd_dict = { "cuda": self.__torch.cuda, "xpu": self.__torch.xpu, + "rocm": self.__torch.cuda, } torch_cmd = torch_cmd_dict[self.pytorch_device] diff --git a/backend/src/utils/Util.py b/backend/src/utils/Util.py index 40778f38b..3f9213a16 100644 --- a/backend/src/utils/Util.py +++ b/backend/src/utils/Util.py @@ -59,11 +59,14 @@ def errorAndLog(message: str): -def log(message: str): +def log(message: str, show_backend = True): """ Log is now depricated, just using print now. 
""" - print("BACKEND: " + message, file=sys.stderr) + if show_backend: + print("BACKEND: " + message, file=sys.stderr) + else: + print(message, file=sys.stderr) #message = message + "\n\n\n\n" + "-" * len(message) #print(message, file=sys.stderr) diff --git a/backend/src/utils/VideoInfo.py b/backend/src/utils/VideoInfo.py index c5df06532..7c9babe48 100644 --- a/backend/src/utils/VideoInfo.py +++ b/backend/src/utils/VideoInfo.py @@ -101,6 +101,7 @@ class FFMpegInfoWrapper(VideoInfo): def __init__(self, input_file: str): self.input_file = input_file self.stream_line = None + self.stream_line_2 = None self._get_ffmpeg_info() def _get_ffmpeg_info(self): @@ -124,9 +125,14 @@ def _get_ffmpeg_info(self): for line in self.ffmpeg_output_raw.split("\n"): if "Stream #" in line and "Video" in line: self.stream_line = line + self.ffmpeg_output_raw = self.ffmpeg_output_raw.replace(line, "") + break + + for line in self.ffmpeg_output_raw.split("\n"): + if "Stream #" in line and "Video" in line: + self.stream_line_2 = line + self.ffmpeg_output_raw = self.ffmpeg_output_raw.replace(line, "") break - - if self.stream_line is None: log("No video stream found in the input file.") except Exception: @@ -157,18 +163,25 @@ def get_fps(self) -> float: def check_color_opt(self, color_opt:str) -> str | None: if self.stream_line: + if "ffv1" in self.get_codec(): + string_pattern = "1," + else: + string_pattern = ")," try: match color_opt: case "Space": - color_opt_detected = self.stream_line.split("),")[1].split(",")[1].split("/")[0].strip() + color_opt_detected = self.stream_line_2.split(",")[1].split("(")[1].strip() if color_opt_detected not in FFMPEG_COLORSPACES: - return None + color_opt_detected = self.stream_line.split(string_pattern)[1].split(",")[1].split("/")[0].strip() + if color_opt_detected not in FFMPEG_COLORSPACES: + return None + case "Primaries": - color_opt_detected = self.stream_line.split("),")[1].split("/")[1].strip() + color_opt_detected = self.stream_line.split(string_pattern)[1].split("/")[1].strip() if color_opt_detected not in FFMPEG_COLOR_PRIMARIES: return None case "Transfer": - color_opt_detected = self.stream_line.split("),")[1].split("/")[2].replace(")","").split(",")[0].strip() + color_opt_detected = self.stream_line.split(string_pattern)[1].split("/")[2].replace(")","").split(",")[0].strip() if color_opt_detected not in FFMPEG_COLOR_TRC: return None @@ -317,7 +330,7 @@ def print_video_info(video_info: VideoInfo): __all__ = ["FFMpegInfoWrapper", "OpenCVInfo", "print_video_info"] if __name__ == "__main__": - video_path = "/home/pax/Downloads/Life Untouched 4K Demo.mp4" + video_path = "/home/pax/Downloads/ffv1_youtube_test2.mkv" #video_path = "/home/pax/Documents/test/LG New York HDR UHD 4K Demo.ts" #video_path = "/home/pax/Documents/test/out.mkv" #video_path = "/home/pax/Videos/TVアニメ「WIND BREAKER Season 2」ノンクレジットオープニング映像「BOYZ」SixTONES [AWlUVr7Du04]_gmfss-pro_deh264-span_janai-v2_72.0fps_3840x2160.mkv" diff --git a/backend/src/version.py b/backend/src/version.py index 1bd32dd32..097b2fa70 100644 --- a/backend/src/version.py +++ b/backend/src/version.py @@ -1,2 +1,2 @@ -__version__ = "2.3.7-dev10" # this is the version of the backend, this is compared to the version of the front end. need to be the same +__version__ = "2.3.8-dev13" # this is the version of the backend, this is compared to the version of the front end. 
need to be the same diff --git a/build.py b/build.py index 28b2ef1a3..2233cebac 100644 --- a/build.py +++ b/build.py @@ -30,9 +30,9 @@ def set_mainwindow_size_zero(path="testRVEInterface.ui"): width = geometry.find("width") height = geometry.find("height") if width is not None: - width.text = "0" + width.text = "1000" if height is not None: - height.text = "0" + height.text = "700" tree.write(path) set_mainwindow_size_zero() diff --git a/src/GenerateFFMpegCommand.py b/src/GenerateFFMpegCommand.py index 94833fc59..274af7186 100644 --- a/src/GenerateFFMpegCommand.py +++ b/src/GenerateFFMpegCommand.py @@ -37,7 +37,7 @@ def build_command(self): self._color_transfer, ] encoder_params += f":transfer={self._color_transfer}:" - if self._color_space is not None and self._video_pixel_format != "yuv420p": + if self._color_space is not None: command += [ "-colorspace", self._color_space, diff --git a/src/ModelHandler.py b/src/ModelHandler.py index f466ed927..8b9a39a62 100644 --- a/src/ModelHandler.py +++ b/src/ModelHandler.py @@ -62,12 +62,7 @@ 1, "gmfss", ), - "GIMM (Slow Model, Realistic/General)": ( - "GIMMVFI_RAFT.pth", - "GIMMVFI_RAFT.pth", - 1, - "gimm", - ), + "IFRNet (Fast Model, Realistic only)": ( "IFRNet_Vimeo90K.pth", "IFRNet_Vimeo90K.pth", @@ -609,9 +604,9 @@ model_path = os.path.join(CUSTOM_MODELS_PATH, model) if os.path.exists(model_path): if not os.path.isfile(model_path): - customNCNNUpscaleModels[model] = (model, model, 1, "custom") + customNCNNUpscaleModels[model] = (model, model, 4, "custom") if model.endswith(".pth") or model.endswith(".safetensors"): - customPytorchUpscaleModels[model] = (model, model, 1, "custom") + customPytorchUpscaleModels[model] = (model, model, 4, "custom") pytorchUpscaleModels = pytorchUpscaleModels | customPytorchUpscaleModels diff --git a/src/Util.py b/src/Util.py index e64f9cad7..4d502ee6f 100644 --- a/src/Util.py +++ b/src/Util.py @@ -134,7 +134,44 @@ def getUnusedFileName(base_file_name: str, outputDirectory: str, extension: str) ) iteration += 1 return output_file + @staticmethod + def getDefaultOutputFolder() -> str: + """ + Returns the default output folder based on the operating system. 
+ """ + videos_folder = os.path.join(f"{HOME_PATH}", "Videos") + + if PLATFORM == "linux": + try: + result = subprocess.run( + ["xdg-user-dir", "VIDEOS"], capture_output=True, text=True + ).stdout.strip() + if os.path.isdir(result): + videos_folder = result + except Exception as e: + log(f"An error occurred while getting the Videos folder on Linux: {e}") + + if PLATFORM == "win32": + try: + import ctypes + from ctypes import wintypes + + CSIDL_MYVIDEO = 0x000e + SHGFP_TYPE_CURRENT = 0 + + buf = ctypes.create_unicode_buffer(wintypes.MAX_PATH) + ctypes.windll.shell32.SHGetFolderPathW( + None, CSIDL_MYVIDEO, None, SHGFP_TYPE_CURRENT, buf + ) + if os.path.isdir(buf.value): + videos_folder = buf.value + except Exception as e: + log(f"An error occurred while getting the Videos folder on Windows: {e}") + + if PLATFORM == "darwin": + videos_folder = os.path.join(f"{HOME_PATH}", "Desktop") + return videos_folder def log(message: str): diff --git a/src/ui/DownloadTab.py b/src/ui/DownloadTab.py index 08b680482..b32d2002d 100644 --- a/src/ui/DownloadTab.py +++ b/src/ui/DownloadTab.py @@ -1,46 +1,12 @@ import os -from PySide6.QtWidgets import QMainWindow -from .QTcustom import RegularQTPopup, NetworkCheckPopup, addNotificationToButton, remove_combobox_item_by_text +from PySide6.QtWidgets import QMainWindow, QMessageBox +from .QTcustom import RegularQTPopup, NetworkCheckPopup, remove_combobox_item_by_text from ..DownloadDeps import DownloadDependencies -from ..DownloadModels import DownloadModel -from ..ModelHandler import ( - ncnnInterpolateModels, - pytorchInterpolateModels, - ncnnUpscaleModels, - pytorchUpscaleModels, -) -import os - - -from PySide6.QtWidgets import QMessageBox - -from PySide6.QtWidgets import QMessageBox from .Updater import ApplicationUpdater -from ..constants import IS_FLATPAK, MODELS_PATH, PLATFORM, CWD, USE_LOCAL_BACKEND, HOME_PATH, BACKEND_PATH, PYTHON_EXECUTABLE_PATH, PYTHON_DIRECTORY, PLATFORM, IS_FLATPAK, CWD, CPU_ARCH +from ..constants import IS_FLATPAK, PLATFORM, CWD, USE_LOCAL_BACKEND, HOME_PATH, PLATFORM, IS_FLATPAK, CWD, CPU_ARCH from ..BuiltInTorchVersions import TorchVersion from ..Util import FileHandler -def downloadModelsBasedOnInstalledBackend(installed_backends: list): - if NetworkCheckPopup(): - for backend in installed_backends: - match backend: - case "ncnn": - for model in ncnnInterpolateModels: - DownloadModel( - model, ncnnInterpolateModels[model][1], MODELS_PATH - ) - for model in ncnnUpscaleModels: - DownloadModel(model, ncnnUpscaleModels[model][1], MODELS_PATH) - case "pytorch": # no need for tensorrt as it uses pytorch models - for model in pytorchInterpolateModels: - DownloadModel( - model, pytorchInterpolateModels[model][1], MODELS_PATH - ) - for model in pytorchUpscaleModels: - DownloadModel( - model, pytorchUpscaleModels[model][1], MODELS_PATH - ) - class DownloadTab: def __init__( @@ -115,14 +81,7 @@ def QButtonConnect(self): self.parent.downloadDirectMLBtn.clicked.connect( lambda: self.download("directml", True) ) - self.parent.downloadAllModelsBtn.clicked.connect( - lambda: downloadModelsBasedOnInstalledBackend( - ["ncnn", "pytorch", "tensorrt", "directml"] - ) - ) - self.parent.downloadSomeModelsBasedOnInstalledBackendbtn.clicked.connect( - lambda: downloadModelsBasedOnInstalledBackend(self.backends) - ) + self.parent.uninstallNCNNBtn.clicked.connect( lambda: self.download("ncnn", False) ) diff --git a/src/ui/ProcessTab.py b/src/ui/ProcessTab.py index 912f304c8..4b20c548d 100644 --- a/src/ui/ProcessTab.py +++ b/src/ui/ProcessTab.py @@ -104,6 
+104,16 @@ def onTilingSwitch(self): self.tileUpAnimationHandler.moveUpAnimation(self.parent.tileSizeContainer) self.parent.tileSizeContainer.setVisible(False) + def openOutputFolderInExplorer(self): + output_folder = os.path.dirname(self.parent.outputFileText.text()) + if os.path.isdir(output_folder): + if PLATFORM == "win32": + os.startfile(output_folder) + elif PLATFORM == "darwin": + subprocess.Popen(["open", output_folder]) + else: + subprocess.Popen(["xdg-open", output_folder]) + def QConnect(self): # connect file select buttons self.parent.addToRenderQueueButton.clicked.connect(self.parent.addToRenderQueue) @@ -120,6 +130,7 @@ def QConnect(self): self.parent.batchSelectButton.clicked.connect(self.parent.openBatchFiles) self.parent.inputFileText.textChanged.connect(self.parent.loadVideo) self.parent.outputFileSelectButton.clicked.connect(self.parent.openOutputFolder) + self.parent.openOutputFolderButton.clicked.connect(self.openOutputFolderInExplorer) # connect render button self.parent.startRenderButton.clicked.connect(self.parent.startRender) # set tile size visible to false by default diff --git a/src/ui/SettingsTab.py b/src/ui/SettingsTab.py index 1ed9c037f..8c379cc5d 100644 --- a/src/ui/SettingsTab.py +++ b/src/ui/SettingsTab.py @@ -2,7 +2,7 @@ from PySide6.QtWidgets import QMainWindow, QFileDialog from ..constants import PLATFORM, HOME_PATH -from ..Util import currentDirectory, checkForWritePermissions, open_folder, log +from ..Util import currentDirectory, checkForWritePermissions, open_folder, log, FileHandler from .QTcustom import RegularQTPopup from ..GenerateFFMpegCommand import FFMpegCommand from ..VideoInfo import VideoLoader @@ -21,6 +21,8 @@ def __init__( self.color_space = None self.color_primaries = None self.color_transfer = None + self.in_pix_fmt = "" + self.hdr_mode = False self.ffmpeg_settings_dict = { "encoder": self.parent.encoder, "audio_encoder": self.parent.audio_encoder, @@ -54,9 +56,16 @@ def updateFFMpegCommand(self): for key, value in self.ffmpeg_settings_dict.items(): self.settings.writeSetting(key, value.currentText()) - pixel_fmt = self.settings.settings['video_pixel_format'] - - hdr_mode = False + self.out_pixel_fmt = self.settings.settings['video_pixel_format'] + pxfmtDict = { + "yuv420p": "yuv420p", + "yuv422p": "yuv422p", + "yuv444p": "yuv444p", + "yuv420p (10 bit)": "yuv420p10le", + "yuv422p (10 bit)": "yuv422p10le", + "yuv444p (10 bit)": "yuv444p10le", + } + self.out_pixel_fmt = pxfmtDict[self.out_pixel_fmt] input_file = self.parent.inputFileText.text() if input_file and len(input_file) > 1: # caching is nice @@ -65,36 +74,38 @@ def updateFFMpegCommand(self): self.ffmpegInfoWrapper = VideoLoader(self.input_file) self.ffmpegInfoWrapper.loadVideo() self.ffmpegInfoWrapper.getData() + self.hdr_mode = (self.ffmpegInfoWrapper.is_hdr) and self.settings.settings['auto_hdr_mode'] == "True" + self.color_space = self.ffmpegInfoWrapper.color_space + self.color_primaries = self.ffmpegInfoWrapper.color_primaries + self.color_transfer = self.ffmpegInfoWrapper.color_transfer + self.in_pix_fmt = self.ffmpegInfoWrapper.pixel_format + + + if self.hdr_mode or ("10" in self.in_pix_fmt and self.settings.settings['auto_hdr_mode'] == "True"): + pxfmtDict = { + "yuv420p": "yuv420p10le", + "yuv422p": "yuv422p10le", + "yuv444p": "yuv444p10le", + } + + if self.out_pixel_fmt in pxfmtDict: + self.out_pixel_fmt = pxfmtDict[self.out_pixel_fmt] - if self.ffmpegInfoWrapper: - hdr_mode = (self.ffmpegInfoWrapper.is_hdr) and self.settings.settings['auto_hdr_mode'] == "True" - if 
hdr_mode: - pxfmtdict = { - "yuv420p": "yuv420p10le", - "yuv422": "yuv422p10le", - "yuv444": "yuv444p10le", - } - - if pixel_fmt in pxfmtdict: - pixel_fmt = pxfmtdict[pixel_fmt] - self.color_space = self.ffmpegInfoWrapper.color_space - self.color_primaries = self.ffmpegInfoWrapper.color_primaries - self.color_transfer = self.ffmpegInfoWrapper.color_transfer command = FFMpegCommand( - self.settings.settings['encoder'].replace(' (experimental)', '').replace(' (40 series and up)', ''), - self.settings.settings['video_encoder_speed'], - self.settings.settings['video_quality'], - pixel_fmt, - self.settings.settings['audio_encoder'], - self.settings.settings['audio_bitrate'], - hdr_mode, - self.color_space, - self.color_primaries, - self.color_transfer, - ).build_command() - self.parent.EncoderCommand.setText(" ".join(command), - ) + self.settings.settings['encoder'].replace(' (experimental)', '').replace(' (40 series and up)', ''), + self.settings.settings['video_encoder_speed'], + self.settings.settings['video_quality'], + self.out_pixel_fmt, + self.settings.settings['audio_encoder'], + self.settings.settings['audio_bitrate'], + self.hdr_mode, + self.color_space if self.in_pix_fmt != "yuv420p" else None, + self.color_primaries, + self.color_transfer, + ).build_command() + self.parent.EncoderCommand.setText(" ".join(command)) + self.parent.updateVideoGUIText() def connectWriteSettings(self): @@ -361,6 +372,7 @@ def __init__(self): The default settings are set here, and are overwritten by the settings in the settings file if it exists and the legnth of the settings is the same as the default settings. The key is equal to the name of the widget of the setting in the settings tab. """ + output_folder_default = FileHandler.getDefaultOutputFolder() self.defaultSettings = { "precision": "auto", "tensorrt_optimization_level": "3", @@ -375,12 +387,8 @@ def __init__(self): "scene_change_detection_threshold": "3.5", "discord_rich_presence": "False", "video_quality": "High", - "output_folder_location": os.path.join(f"{HOME_PATH}", "Videos") - if PLATFORM != "darwin" - else os.path.join(f"{HOME_PATH}", "Desktop"), - "last_input_folder_location": os.path.join(f"{HOME_PATH}", "Videos") - if PLATFORM != "darwin" - else os.path.join(f"{HOME_PATH}", "Desktop"), + "output_folder_location": output_folder_default, + "last_input_folder_location": output_folder_default, "uhd_mode": "True", "ncnn_gpu_id": "0", "pytorch_gpu_id": "0", diff --git a/src/version.py b/src/version.py index 85a982dc0..e439f0d2f 100644 --- a/src/version.py +++ b/src/version.py @@ -1,2 +1,2 @@ -version = "2.3.7" -backend_dev_version = "2.3.7-dev10" # has to match version of backend, update this wehenver updating pre release +version = "2.3.8" +backend_dev_version = "2.3.8-dev13" # has to match version of backend, update this wehenver updating pre release diff --git a/testRVEInterface.ui b/testRVEInterface.ui index 9f9574f81..a89bde3aa 100644 --- a/testRVEInterface.ui +++ b/testRVEInterface.ui @@ -5,8 +5,8 @@ 0 0 - 0 - 0 + 1000 + 700 @@ -225,7 +225,7 @@ QScrollBar:horizontal{ - 2 + 4 @@ -892,7 +892,7 @@ background-color: #1f232a - 500 + 600 16777215 @@ -934,7 +934,7 @@ background-color: #1f232a - 500 + 600 16777215 @@ -1089,9 +1089,9 @@ background-color:#343b47 0 - -454 - 470 - 878 + 0 + 516 + 984 @@ -1182,34 +1182,34 @@ background-color:#343b47 - - - - - 0 - 0 - - - - - 0 - 30 - - - - - 230 - 30 - - - - false - - - + + + + + 0 + 0 + + + + + 0 + 30 + + + + + 1000 + 30 + + + + false + + + @@ -1256,16 +1256,10 @@ background-color:#343b47 - + 
false - - - 0 - 0 - - 0 @@ -1274,15 +1268,64 @@ background-color:#343b47 - 230 + 16777215 30 + + + 12 + + + + + + + Open Output Folder + + + + + false + + + + 0 + 0 + + + + + 0 + 30 + + + + + 10000 + 30 + + + + + + + + false + + + background: white + + + Qt::Orientation::Horizontal + + + @@ -2159,52 +2202,67 @@ text-align: center; - - - false - + - + 0 0 - - - 0 - 190 - - - 100000 - 180 + 42344 + 200 - - - - - + + + + + false + + + + 0 + 0 + + + + + 0 + 190 + + + + + 100000 + 180 + + + + + + + QTextEdit:disabled{color:white;} - - - Qt::ScrollBarPolicy::ScrollBarAlwaysOff - - - Qt::ScrollBarPolicy::ScrollBarAlwaysOff - - - QAbstractScrollArea::SizeAdjustPolicy::AdjustToContents - - - QTextEdit::LineWrapMode::NoWrap - - - true - - - <!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 4.0//EN" "http://www.w3.org/TR/REC-html40/strict.dtd"> + + + Qt::ScrollBarPolicy::ScrollBarAlwaysOff + + + Qt::ScrollBarPolicy::ScrollBarAlwaysOff + + + QAbstractScrollArea::SizeAdjustPolicy::AdjustToContents + + + QTextEdit::LineWrapMode::NoWrap + + + true + + + <!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 4.0//EN" "http://www.w3.org/TR/REC-html40/strict.dtd"> <html><head><meta name="qrichtext" content="1" /><meta charset="utf-8" /><style type="text/css"> p, li { white-space: pre-wrap; } hr { height: 1px; border-width: 0; } @@ -2214,14 +2272,16 @@ li.checked::marker { content: "\2612"; } <p style=" margin-top:0px; margin-bottom:0px; margin-left:0px; margin-right:0px; -qt-block-indent:0; text-indent:0px;"><span style=" font-size:10pt;">FPS:</span></p> <p style=" margin-top:0px; margin-bottom:0px; margin-left:0px; margin-right:0px; -qt-block-indent:0; text-indent:0px;"><span style=" font-size:10pt;">Resolution:</span></p> <p style=" margin-top:0px; margin-bottom:0px; margin-left:0px; margin-right:0px; -qt-block-indent:0; text-indent:0px;"><span style=" font-size:10pt;">Frame Count:</span></p> -<p style=" margin-top:0px; margin-bottom:0px; margin-left:0px; margin-right:0px; -qt-block-indent:0; text-indent:0px;"><span style=" font-size:10pt;">Bitrate:</span></p> <p style=" margin-top:0px; margin-bottom:0px; margin-left:0px; margin-right:0px; -qt-block-indent:0; text-indent:0px;"><span style=" font-size:10pt;">Encoder:</span></p> <p style=" margin-top:0px; margin-bottom:0px; margin-left:0px; margin-right:0px; -qt-block-indent:0; text-indent:0px;"><span style=" font-size:10pt;">Container:</span></p> <p style=" margin-top:0px; margin-bottom:0px; margin-left:0px; margin-right:0px; -qt-block-indent:0; text-indent:0px;"><span style=" font-size:10pt;">Color Space:</span></p> <p style=" margin-top:0px; margin-bottom:0px; margin-left:0px; margin-right:0px; -qt-block-indent:0; text-indent:0px;"><span style=" font-size:10pt;">Pixel Format:</span></p> <p style=" margin-top:0px; margin-bottom:0px; margin-left:0px; margin-right:0px; -qt-block-indent:0; text-indent:0px;"><span style=" font-size:10pt;">Is HDR:</span></p> <p style=" margin-top:0px; margin-bottom:0px; margin-left:0px; margin-right:0px; -qt-block-indent:0; text-indent:0px;"><span style=" font-size:10pt;">Bit Depth:</span></p></body></html> - + + + + @@ -2233,238 +2293,41 @@ li.checked::marker { content: "\2612"; } - - - - 0 - 0 - - - - - 16777215 - 16777215 - - + - Render Queue + Encoder Options - + - - - - 0 - - - 0 - - - 0 - - - 0 - + + - - - - 0 - 0 - - - - - 0 - 0 - - - - - 335 - 16777215 - - - - QListWidget{ -background-color:#1f232a; -border-radius:10px; -} - - + + true - - - - - - - - - Remove - - - - - - - Move Up - - - - - - - Move Down - - - - - - - - - - - - - 0 - - - - - 
- 0 - 0 - - - - - 0 - 0 - - - - - 16777215 - 16777215 - - - - QListWidget{ -background-color:#1f232a; -border-radius:10px; -} - - - - - - - Finished Output Files: - - - - - - - - - - - 0 - 0 - - - - - 16777215 - 16777215 - - - - Advanced - - - - - - - 0 - 0 - - - - true - - - Qt::AlignmentFlag::AlignLeading|Qt::AlignmentFlag::AlignLeft|Qt::AlignmentFlag::AlignTop - - - - - 0 - 0 - 393 - 510 - - - - - 0 - 0 - - - - - 0 - 0 - - - - - 9 - - - 11 - - - 11 - - - 11 - - - 11 - - - - - - 0 - 0 - + + + + 0 + 0 + 428 + 484 + - - - 0 - - - 0 - - - 0 - - - 0 - - - 0 - + - - + + + + 0 + 0 + + + + + 0 + 0 @@ -2478,64 +2341,269 @@ border-radius:10px; 0 - - - - 15 - true - - - - Encoder Command - + + + + 0 + + + 0 + + + 0 + + + 0 + + + + + + 15 + true + + + + Encoder Settings + + + + + + + + + + + 0 + 50 + + + + Qt::ScrollBarPolicy::ScrollBarAlwaysOff + + + Qt::ScrollBarPolicy::ScrollBarAsNeeded + + + true + + + + + 0 + 0 + 918 + 48 + + + + + + + + 900 + 30 + + + + + + + + + + + + + 25 + 25 + + + + <html><head/><body><p><span style=" font-weight:700;">Encoder command:</span></p><p><span style=" font-weight:700;"> -</span> Encoder settings passed to FFMpeg, can be manually tweaked.</p></body></html> + + + + + + :/icons/icons/info.svg + + + true + + + + + + + + + + + + + - - + + - - + + + + 12 + + + + Video Encoder + + + + + + + Qt::Orientation::Horizontal + + - 0 - 50 + 40 + 20 + + + + + + + + + 25 + 25 - - Qt::ScrollBarPolicy::ScrollBarAlwaysOff + + <html><head/><body><p><span style=" font-weight:700;">Encoder:</span></p><p>- Compression algorithm for video, most common is libx264</p><p>- Any encoder without a device label is a CPU encoder.</p><p>- <span style=" font-weight:700; text-decoration: underline;">Lossy Encoders:</span></p><p>- libx264 (<span style=" font-weight:700;">mkv, mp4, mov, avi</span>)</p><p>- libx265 (<span style=" font-weight:700;">mkv, mp4, mov</span>)</p><p>- vp9 (<span style=" font-weight:700;">mkv, webm</span>)</p><p>- av1 (<span style=" font-weight:700;">mkv, mp4, webm</span>)</p><p>- <span style=" font-weight:700; text-decoration: underline;">Lossless Encoders:</span></p><p>- prores (<span style=" font-weight:700;">mkv, mov</span>)</p><p>- ffv1 (<span style=" font-weight:700;">mkv, avi</span>)</p></body></html> + + + - - Qt::ScrollBarPolicy::ScrollBarAsNeeded + + :/icons/icons/info.svg - + true - - - - 0 - 0 - 918 - 48 - + + + + + + + 0 + 0 + + + + + libx264 - - - - - - 900 - 30 - - - - - - + + + + libx265 + + + + + vp9 + + + + + av1 + + + + + prores + + + + + ffv1 + + + + + x264_vulkan (experimental) + + + + + x264_nvenc + + + + + x265_nvenc + + + + + av1_nvenc (40 series and up) + + + + + x264_vaapi + + + + + x265_vaapi + + + + + av1_vaapi + + + + + + + + + + + + + + + 12 + + + + Video Encoder Speed + - + + + Qt::Orientation::Horizontal + + + + 237 + 20 + + + + + + 25 @@ -2543,7 +2611,7 @@ border-radius:10px; - <html><head/><body><p><span style=" font-weight:700;">Encoder command:</span></p><p><span style=" font-weight:700;"> -</span> Encoder settings passed to FFMpeg, can be manually tweaked.</p></body></html> + <html><head/><body><p><span style=" font-weight:700;">Encoder Speed:</span></p><p> - Speed at which the encoder processes the video, <span style=" font-weight:700; text-decoration: underline;">Slower</span> processing time means <span style=" font-weight:700; text-decoration: underline;">Better results for the file size</span><span style=" text-decoration: underline;">.</span></p></body></html> @@ -2556,327 +2624,444 @@ border-radius:10px; + + + + + 0 + 0 + + + + + 
fastest + + + + + fast + + + + + medium + + + + + slow + + + + + placebo + + + + - - - - - - - - - - - 0 - 0 - - - - - 6 - - - 0 - - - 0 - - - 0 - - - 0 - - - - - - 15 - true - - - - General Settings - - - - - - - - 6 - - - 9 - - - 9 - - - 9 - - - 9 - - - - - 12 - - - - - - - - - - Benchmark Mode - + + + + + + + 12 + + + + Video Quality + + + + + + + Qt::Orientation::Horizontal + + + + 40 + 20 + + + + + + + + + 0 + 0 + + + + + Very_High + + + + + High + + + + + Medium + + + + + Low + + + + + - - - Qt::Orientation::Horizontal - - - - 40 - 20 - - - - - - - - - 25 - 25 - - - - <html><head/><body><p>Perform processing without outputing new video. This tests the raw performance of the inference.</p></body></html> - - - - - - :/icons/icons/info.svg - - - true - + + + + + + + 12 + + + + Video Container + + + + + + + Qt::Orientation::Horizontal + + + + 437 + 20 + + + + + + + + + 25 + 25 + + + + <html><head/><body><p><span style=" font-weight:700;">Video Container</span></p><p>- Changes the video container in the default output file generated container.</p><p>- <span style=" font-weight:700;">mkv </span>is <span style=" font-weight:700;">recomended</span> due to its extensive support for different formats.</p><p><span style=" font-weight:700;">WARNING: Changing this </span><span style=" font-weight:700; text-decoration: underline;">MAY</span><span style=" font-weight:700;"> break encoder compadibility. </span></p><p><span style=" font-weight:700;">Please look at the audio and video encoder tooltips for compadibility.</span></p></body></html> + + + + + + :/icons/icons/info.svg + + + true + + + + + + + + mkv + + + + + mp4 + + + + + mov + + + + + webm + + + + + avi + + + + + - - - - - - - + + + + + + + 12 + + + + Video Pixel Format + + + + + + + Qt::Orientation::Horizontal + + + + 469 + 20 + + + + + + + + + 25 + 25 + + + + <html><head/><body><p><span style=" font-weight:700;">Pixel Format</span></p><p>- D<span style=" font-family:'Droid Sans Mono','monospace','monospace'; font-size:9pt; color:#cccccc;">efines how color information is stored for each pixel in a video frame.</span></p><p><span style=" font-family:'Droid Sans Mono','monospace','monospace'; font-size:9pt; color:#cccccc;">- yuv420p: Most common, removes the most information. </span>(<span style=" font-weight:700;">mkv, mp4, mov, avi, webm</span>)</p><p><span style=" font-family:'Droid Sans Mono','monospace','monospace'; font-size:9pt; color:#cccccc;">- yuv422p: In between 420p and 444p. </span>(<span style=" font-weight:700;">mkv, mov, mp4</span>)</p><p><span style=" font-family:'Droid Sans Mono','monospace','monospace'; font-size:9pt; color:#cccccc;">- yuv444p: Keeps the most information. </span>(<span style=" font-weight:700;">mkv, mov, mp4</span>)<br/></p><p>- <span style=" font-family:'Droid Sans Mono','monospace','monospace'; font-size:9pt; color:#cccccc;">yuv 420/422/444 p10le: 10 bit version of each. 
(MKV and MP4 containers only) </span>(<span style=" font-weight:700;">mkv, mov, mp4</span>)</p></body></html> + + + + + + :/icons/icons/info.svg + + + true + + + + + + + + 0 + 0 + + + + + yuv420p + + + + + yuv422p + + + + + yuv444p + + + + + yuv420p (10 bit) + + + + + yuv422p (10 bit) + + + + + yuv444p (10 bit) + + + + + - - - - - - - - - - - 0 - 0 - - - - - 351 - 128 - - - - - 0 - - - 0 - - - 0 - - - 0 - - - 0 - - - - - - 15 - true - - - - Upscale/Restoration Tile Settings - - - - - - - - 6 - - - 9 - - - 0 - - - 9 - - - 9 - - - - - 12 - - - - Tiling - - - - - - - Qt::Orientation::Horizontal - - - - 40 - 20 - - - - - - - - - 25 - 25 - - - - <html><head/><body><p>Split up processing upscaled frames into chunks.</p><p>Lowers VRAM usage, but also slows down render. </p><p>Only use when render failes due to VRAM limits.</p></body></html> - - - - - - :/icons/icons/info.svg - - - true - + + + + + + + 12 + + + + Audio Encoder + + + + + + + Qt::Orientation::Horizontal + + + + 382 + 20 + + + + + + + + + 25 + 25 + + + + <html><head/><body><p><span style=" font-weight:700;">Audio Encoder</span></p><p>- <span style=" font-weight:700;">Recommended </span>to use <span style=" font-weight:700;">aac</span> if the video container is <span style=" font-weight:700;">not mkv, </span>or if the output video has audio isses.</p><p><span style=" font-weight:700;">- </span>copy_audio: No re-encoding is done on the output</p><p>- aac: Most common, and high quality.</p><p>- libmp3lame: Used less, low quality.</p><p>- opus: Used for <span style=" font-weight:700;">webm</span>.</p></body></html> + + + + + + :/icons/icons/info.svg + + + true + + + + + + + + 0 + 0 + + + + + copy_audio + + + + + aac + + + + + libmp3lame + + + + + opus + + + + + - - - - - - - - - - - - - - 6 - - - 9 - - - 0 - - - 9 - - - 0 - - - - - - 12 - - - - Tile Size - - - - - - - Qt::Orientation::Horizontal - - - - 40 - 20 - - - - - - - - - 512 - - - - - 384 - - - - - 256 - - - - - 128 - - - - - 64 - - + + + + + + + 12 + + + + Audio Bitrate + + + + + + + Qt::Orientation::Horizontal + + + + 494 + 20 + + + + + + + + + 25 + 25 + + + + <html><head/><body><p><span style=" font-weight:700;">Audio Bitrate</span></p><p><span style=" font-weight:700;">- </span>Higher bitrate = higher quality</p><p> - Only applies to <span style=" font-weight:700;">aac </span>and <span style=" font-weight:700;">libmp3lame</span>.</p></body></html> + + + + + + :/icons/icons/info.svg + + + true + + + + + + + + 0 + 0 + + + + + 320k + + + + + 192k + + + + + 128k + + + + + 96k + + + + + @@ -2884,18 +3069,229 @@ border-radius:10px; - + + + + + + + + + + + 0 + 0 + + + + + 16777215 + 16777215 + + + + Render Queue + + + + + + + 0 + + + 0 + + + 0 + + + 0 + + + + + + 0 + 0 + + + + + 0 + 0 + + + + + 335 + 16777215 + + + + QListWidget{ +background-color:#1f232a; +border-radius:10px; +} + + + true + + + + + + + + + + Remove + + + + + + + Move Up + + + + + + + Move Down + + + + + + + + + + + + + 0 + + + + + + 0 + 0 + + + + + 0 + 0 + + + + + 16777215 + 16777215 + + + + QListWidget{ +background-color:#1f232a; +border-radius:10px; +} + + + + + + + Finished Output Files: + + + + + + + + + + + 0 + 0 + + + + + 16777215 + 16777215 + + + + Advanced + + + + + + + 0 + 0 + + + + true + + + Qt::AlignmentFlag::AlignLeading|Qt::AlignmentFlag::AlignLeft|Qt::AlignmentFlag::AlignTop + + + + + 0 + 0 + 393 + 430 + + + + + 0 + 0 + + + + + 0 + 0 + + + + + 9 + + + 11 + + + 11 + + + 11 + + + 11 + - + 0 0 - + - 0 + 6 0 @@ -2910,7 +3306,7 @@ border-radius:10px; 0 - + 15 @@ -2918,40 +3314,61 @@ border-radius:10px; - 
Interpolate Settings + General Settings - - + + + + 6 + + + 9 + + + 9 + + + 9 + + + 9 + - + 12 + + + + + + - Dynamic Scaled Flow (Pytorch Only) + Benchmark Mode - + Qt::Orientation::Horizontal - 226 + 40 20 - + 25 @@ -2959,7 +3376,7 @@ border-radius:10px; - <html><head/><body><p>Scales optical flow based on the difference between frames.</p><p>Helps with anime interpolation.</p></body></html> + <html><head/><body><p>Perform processing without outputing new video. This tests the raw performance of the inference.</p></body></html> @@ -2973,7 +3390,10 @@ border-radius:10px; - + + + + @@ -2982,36 +3402,97 @@ border-radius:10px; + + + + + + + + 0 + 0 + + + + + 351 + 128 + + + + + 0 + + + 0 + + + 0 + + + 0 + + + 0 + + + + + + 15 + true + + + + Upscale/Restoration Tile Settings + + + - - + + + + 6 + + + 9 + + + 0 + + + 9 + + + 9 + - + 12 - Ensemble (Pytorch/TensorRT Only) + Tiling - + Qt::Orientation::Horizontal - 226 + 40 20 - + 25 @@ -3019,7 +3500,7 @@ border-radius:10px; - <html><head/><body><p>Ensembles flow, can produce better results.</p><p>Only compadible with older RIFE models and GMFSS</p></body></html> + <html><head/><body><p>Split up processing upscaled frames into chunks.</p><p>Lowers VRAM usage, but also slows down render. </p><p>Only use when render failes due to VRAM limits.</p></body></html> @@ -3033,7 +3514,7 @@ border-radius:10px; - + @@ -3043,60 +3524,75 @@ border-radius:10px; - - + + + + 6 + + + 9 + + + 0 + + + 9 + + + 0 + - + 12 - SloMo Mode + Tile Size - + Qt::Orientation::Horizontal - 226 + 40 20 - - - - 25 - 25 - - - - <html><head/><body><p>Increase video length instead of framerate when interpolating.</p><p>Audio and Subtitles will not be transfered to output video.</p></body></html> - - - - - - :/icons/icons/info.svg - - - true - - - - - - - - + + + + 512 + + + + + 384 + + + + + 256 + + + + + 128 + + + + + 64 + + @@ -3105,79 +3601,304 @@ border-radius:10px; - - - - Qt::Orientation::Vertical - - - - 20 - 40 - + + + + + 0 + 0 + - - - - - - - - - - - Logs - - - - - - - 0 - 0 - - - - - 13 - - - - *:disabled{ - - color:white; -} - - - Qt::ScrollBarPolicy::ScrollBarAlwaysOff - - - Qt::ScrollBarPolicy::ScrollBarAlwaysOff - - - true - - - <!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 4.0//EN" "http://www.w3.org/TR/REC-html40/strict.dtd"> -<html><head><meta name="qrichtext" content="1" /><meta charset="utf-8" /><style type="text/css"> -p, li { white-space: pre-wrap; } -hr { height: 1px; border-width: 0; } -li.unchecked::marker { content: "\2610"; } -li.checked::marker { content: "\2612"; } -</style></head><body style=" font-family:'Sans Serif'; font-size:13pt; font-weight:400; font-style:normal;"> -<p style="-qt-paragraph-type:empty; margin-top:0px; margin-bottom:0px; margin-left:0px; margin-right:0px; -qt-block-indent:0; text-indent:0px; font-family:'Noto Sans';"><br /></p></body></html> - + + + 0 + + + 0 + + + 0 + + + 0 + + + 0 + + + + + + 15 + true + + + + Interpolate Settings + + + + + + + + + + + 12 + + + + Dynamic Scaled Flow (Pytorch Only) + + + + + + + Qt::Orientation::Horizontal + + + + 226 + 20 + + + + + + + + + 25 + 25 + + + + <html><head/><body><p>Scales optical flow based on the difference between frames.</p><p>Helps with anime interpolation.</p></body></html> + + + + + + :/icons/icons/info.svg + + + true + + + + + + + + + + + + + + + + + + + + + 12 + + + + Ensemble (Pytorch/TensorRT Only) + + + + + + + Qt::Orientation::Horizontal + + + + 226 + 20 + + + + + + + + + 25 + 25 + + + + <html><head/><body><p>Ensembles flow, can produce better results.</p><p>Only compadible with 
older RIFE models and GMFSS</p></body></html> + + + + + + :/icons/icons/info.svg + + + true + + + + + + + + + + + + + + + + + + + + + 12 + + + + SloMo Mode + + + + + + + Qt::Orientation::Horizontal + + + + 226 + 20 + + + + + + + + + 25 + 25 + + + + <html><head/><body><p>Increase video length instead of framerate when interpolating.</p><p>Audio and Subtitles will not be transfered to output video.</p></body></html> + + + + + + :/icons/icons/info.svg + + + true + + + + + + + + + + + + + + + + + + + + + + + + + Qt::Orientation::Vertical + + + + 20 + 40 + + + + + + - - - - - - - + + + Logs + + + + + + + 0 + 0 + + + + + 13 + + + + *:disabled{ + + color:white; +} + + + Qt::ScrollBarPolicy::ScrollBarAlwaysOff + + + Qt::ScrollBarPolicy::ScrollBarAlwaysOff + + + true + + + <!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 4.0//EN" "http://www.w3.org/TR/REC-html40/strict.dtd"> +<html><head><meta name="qrichtext" content="1" /><meta charset="utf-8" /><style type="text/css"> +p, li { white-space: pre-wrap; } +hr { height: 1px; border-width: 0; } +li.unchecked::marker { content: "\2610"; } +li.checked::marker { content: "\2612"; } +</style></head><body style=" font-family:'Sans Serif'; font-size:13pt; font-weight:400; font-style:normal;"> +<p style="-qt-paragraph-type:empty; margin-top:0px; margin-bottom:0px; margin-left:0px; margin-right:0px; -qt-block-indent:0; text-indent:0px; font-family:'Noto Sans';"><br /></p></body></html> + + + + + + + + + + + + @@ -3537,8 +4258,8 @@ QTabWidget::pane { /* The tab widget frame */ 0 0 - 470 - 488 + 863 + 652 @@ -3580,8 +4301,8 @@ QTabWidget::pane { /* The tab widget frame */ 0 0 - 480 - 173 + 805 + 535 @@ -3972,642 +4693,11 @@ QTabWidget::pane { /* The tab widget frame */ - - - - - - - - - - - - - - - Encoders Settings - - - - - - - - - - - - - 12 - - - - Video Encoder - - - - - - - Qt::Orientation::Horizontal - - - - 40 - 20 - - - - - - - - - 25 - 25 - - - - <html><head/><body><p><span style=" font-weight:700;">Encoder:</span></p><p>- Compression algorithm for video, most common is libx264</p><p>- Any encoder without a device label is a CPU encoder.</p><p>- <span style=" font-weight:700; text-decoration: underline;">Lossy Encoders:</span></p><p>- libx264 (<span style=" font-weight:700;">mkv, mp4, mov, avi</span>)</p><p>- libx265 (<span style=" font-weight:700;">mkv, mp4, mov</span>)</p><p>- vp9 (<span style=" font-weight:700;">mkv, webm</span>)</p><p>- av1 (<span style=" font-weight:700;">mkv, mp4, webm</span>)</p><p>- <span style=" font-weight:700; text-decoration: underline;">Lossless Encoders:</span></p><p>- prores (<span style=" font-weight:700;">mkv, mov</span>)</p><p>- ffv1 (<span style=" font-weight:700;">mkv, avi</span>)</p></body></html> - - - - - - :/icons/icons/info.svg - - - true - - - - - - - - 0 - 0 - - - - - libx264 - - - - - libx265 - - - - - vp9 - - - - - av1 - - - - - prores - - - - - ffv1 - - - - - x264_vulkan (experimental) - - - - - x264_nvenc - - - - - x265_nvenc - - - - - av1_nvenc (40 series and up) - - - - - x264_vaapi - - - - - x265_vaapi - - - - - av1_vaapi - - - - - - - - - - - - - - - 12 - - - - Video Encoder Speed - - - - - - - Qt::Orientation::Horizontal - - - - 237 - 20 - - - - - - - - - 25 - 25 - - - - <html><head/><body><p><span style=" font-weight:700;">Encoder Speed:</span></p><p> - Speed at which the encoder processes the video, <span style=" font-weight:700; text-decoration: underline;">Slower</span> processing time means <span style=" font-weight:700; text-decoration: underline;">Better results for the file size</span><span style=" 
text-decoration: underline;">.</span></p></body></html> - - - - - - :/icons/icons/info.svg - - - true - - - - - - - - 0 - 0 - - - - - fastest - - - - - fast - - - - - medium - - - - - slow - - - - - placebo - - - - - - - - - - - - - - - 12 - - - - Video Quality - - - - - - - Qt::Orientation::Horizontal - - - - 40 - 20 - - - - - - - - - 0 - 0 - - - - - Very_High - - - - - High - - - - - Medium - - - - - Low - - - - - - - - - - - - - - - 12 - - - - Video Container - - - - - - - Qt::Orientation::Horizontal - - - - 437 - 20 - - - - - - - - - 25 - 25 - - - - <html><head/><body><p><span style=" font-weight:700;">Video Container</span></p><p>- Changes the video container in the default output file generated container.</p><p>- <span style=" font-weight:700;">mkv </span>is <span style=" font-weight:700;">recomended</span> due to its extensive support for different formats.</p><p><span style=" font-weight:700;">WARNING: Changing this </span><span style=" font-weight:700; text-decoration: underline;">MAY</span><span style=" font-weight:700;"> break encoder compadibility. </span></p><p><span style=" font-weight:700;">Please look at the audio and video encoder tooltips for compadibility.</span></p></body></html> - - - - - - :/icons/icons/info.svg - - - true - - - - - - - - mkv - - - - - mp4 - - - - - mov - - - - - webm - - - - - avi - - - - - - - - - - - - - - - 12 - - - - Video Pixel Format - - - - - - - Qt::Orientation::Horizontal - - - - 469 - 20 - - - - - - - - - 25 - 25 - - - - <html><head/><body><p><span style=" font-weight:700;">Pixel Format</span></p><p>- D<span style=" font-family:'Droid Sans Mono','monospace','monospace'; font-size:9pt; color:#cccccc;">efines how color information is stored for each pixel in a video frame.</span></p><p><span style=" font-family:'Droid Sans Mono','monospace','monospace'; font-size:9pt; color:#cccccc;">- yuv420p: Most common, removes the most information. </span>(<span style=" font-weight:700;">mkv, mp4, mov, avi, webm</span>)</p><p><span style=" font-family:'Droid Sans Mono','monospace','monospace'; font-size:9pt; color:#cccccc;">- yuv422p: In between 420p and 444p. </span>(<span style=" font-weight:700;">mkv, mov, mp4</span>)</p><p><span style=" font-family:'Droid Sans Mono','monospace','monospace'; font-size:9pt; color:#cccccc;">- yuv444p: Keeps the most information. </span>(<span style=" font-weight:700;">mkv, mov, mp4</span>)<br/></p><p>- <span style=" font-family:'Droid Sans Mono','monospace','monospace'; font-size:9pt; color:#cccccc;">yuv 420/422/444 p10le: 10 bit version of each. 
(MKV and MP4 containers only) </span>(<span style=" font-weight:700;">mkv, mov, mp4</span>)</p></body></html> - - - - - - :/icons/icons/info.svg - - - true - - - - - - - - 0 - 0 - - - - - yuv420p - - - - - yuv422p - - - - - yuv444p - - - - - yuv420p10le - - - - - yuv422p10le - - - - - yuv444p10le - - - - - - - - - - - - - - - 12 - - - - Audio Encoder - - - - - - - Qt::Orientation::Horizontal - - - - 382 - 20 - - - - - - - - - 25 - 25 - - - - <html><head/><body><p><span style=" font-weight:700;">Audio Encoder</span></p><p>- <span style=" font-weight:700;">Recommended </span>to use <span style=" font-weight:700;">aac</span> if the video container is <span style=" font-weight:700;">not mkv, </span>or if the output video has audio isses.</p><p><span style=" font-weight:700;">- </span>copy_audio: No re-encoding is done on the output</p><p>- aac: Most common, and high quality.</p><p>- libmp3lame: Used less, low quality.</p><p>- opus: Used for <span style=" font-weight:700;">webm</span>.</p></body></html> - - - - - - :/icons/icons/info.svg - - - true - - - - - - - - 0 - 0 - - - - - copy_audio - - - - - aac - - - - - libmp3lame - - - - - opus - - - - - - - - - - - - - - - 12 - - - - Audio Bitrate - - - - - - - Qt::Orientation::Horizontal - - - - 494 - 20 - - - - - - - - - 25 - 25 - - - - <html><head/><body><p><span style=" font-weight:700;">Audio Bitrate</span></p><p><span style=" font-weight:700;">- </span>Higher bitrate = higher quality</p><p> - Only applies to <span style=" font-weight:700;">aac </span>and <span style=" font-weight:700;">libmp3lame</span>.</p></body></html> - - - - - - :/icons/icons/info.svg - - - true - - - - - - - - 0 - 0 - - - - - 320k - - - - - 192k - - - - - 128k - - - - - 96k - - - - - + + + + + @@ -5313,136 +5403,88 @@ background-color:white; - - - - - 0 - - - + + + + + 0 + 0 + + + + + - 25 - true + 12 - Models + Import Custom Upscale Model - - - - - 0 - - - 0 - - - 0 - - - 0 - - - - - - 0 - 0 - - - - - 50 - 45 - - - - - - - - :/icons/icons/download.svg:/icons/icons/download.svg - - - - 30 - 30 - - - - - - - - All Models for all Backends - - - - - + + + + Qt::Orientation::Horizontal + + + QSizePolicy::Policy::MinimumExpanding + + + + 20 + 20 + + + - - - - - 0 - - - 0 - - - 0 - - - 0 - - - - - - 0 - 0 - - - - - 50 - 45 - - - - - - - - :/icons/icons/download.svg:/icons/icons/download.svg - - - - 30 - 30 - - - - - - - - All models for Installed Backends - - - - - + + + + + + true + + + Select Model (NCNN/.bin+.param) + + + + + + + true + + + Select Model (PyTorch/TensorRT/.pth) + + + + + + + + + 9 + + + 9 + + + 9 + + + 9 + + + + @@ -5520,7 +5562,7 @@ background-color:white; - <html><head/><body><p><span style=" font-weight:700;">PyTorch Version</span></p><p>- This is the version downloaded</p><p>- 2.6: Supports older CUDA enabled GPUs</p><p>- 2.7: Supports all RTX cards, faster and recomended.</p></body></html> + <html><head/><body><p><span style=" font-weight:700;">PyTorch Version</span></p><p>- This is the version downloaded</p><p>- 2.6: Supports older CUDA enabled GPUs</p><p>- 2.8: Supports all RTX cards, faster and recomended.</p><p>- 2.9: Experimental.</p></body></html> 0 @@ -5640,8 +5682,8 @@ color: red; 0 0 - 791 - 162 + 845 + 333 @@ -6002,88 +6044,6 @@ color: red; - - - - - 0 - 0 - - - - - - - - 12 - - - - Import Custom Upscale Model - - - - - - - Qt::Orientation::Horizontal - - - QSizePolicy::Policy::MinimumExpanding - - - - 20 - 20 - - - - - - - - - - true - - - Select Model (NCNN/.bin+.param) - - - - - - - true - - - Select Model (PyTorch/TensorRT/.pth) - - - - - 
- - - - - - - - 9 - - - 9 - - - 9 - - - 9 - - - -