-
Notifications
You must be signed in to change notification settings - Fork 178
Add gradio local inference demo #740
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
b3a7ca3
9ce4794
199bb85
426f1fd
7117a81
f1ed7ef
7b72aa3
3bb3c06
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,169 @@ | ||||||||||||
import argparse | ||||||||||||
import os | ||||||||||||
from copy import deepcopy | ||||||||||||
|
||||||||||||
import gradio as gr | ||||||||||||
import torch | ||||||||||||
|
||||||||||||
from fastvideo import VideoGenerator | ||||||||||||
from fastvideo.configs.sample.base import SamplingParam | ||||||||||||
|
||||||||||||
if __name__ == "__main__": | ||||||||||||
os.environ["FASTVIDEO_ATTENTION_BACKEND"] = "VIDEO_SPARSE_ATTN" | ||||||||||||
parser = argparse.ArgumentParser(description="FastVideo Gradio Demo") | ||||||||||||
parser.add_argument("--model_path", | ||||||||||||
type=str, | ||||||||||||
default="FastVideo/FastWan2.1-T2V-1.3B-Diffusers", | ||||||||||||
help="Path to the model") | ||||||||||||
parser.add_argument("--num_gpus", | ||||||||||||
type=int, | ||||||||||||
default=1, | ||||||||||||
help="Number of GPUs to use") | ||||||||||||
parser.add_argument("--output_path", | ||||||||||||
type=str, | ||||||||||||
default="my_videos/", | ||||||||||||
help="Path to save generated videos") | ||||||||||||
parsed_args = parser.parse_args() | ||||||||||||
|
||||||||||||
|
||||||||||||
generator = VideoGenerator.from_pretrained( | ||||||||||||
model_path=parsed_args.model_path, num_gpus=parsed_args.num_gpus) | ||||||||||||
|
||||||||||||
default_params = SamplingParam.from_pretrained(parsed_args.model_path) | ||||||||||||
|
||||||||||||
def generate_video( | ||||||||||||
prompt, | ||||||||||||
negative_prompt, | ||||||||||||
use_negative_prompt, | ||||||||||||
seed, | ||||||||||||
guidance_scale, | ||||||||||||
num_frames, | ||||||||||||
height, | ||||||||||||
width, | ||||||||||||
num_inference_steps, | ||||||||||||
randomize_seed=False, | ||||||||||||
): | ||||||||||||
params = deepcopy(default_params) | ||||||||||||
params.prompt = prompt | ||||||||||||
params.negative_prompt = negative_prompt | ||||||||||||
params.seed = seed | ||||||||||||
params.guidance_scale = guidance_scale | ||||||||||||
params.num_frames = num_frames | ||||||||||||
params.height = height | ||||||||||||
params.width = width | ||||||||||||
params.num_inference_steps = num_inference_steps | ||||||||||||
|
||||||||||||
if randomize_seed: | ||||||||||||
params.seed = torch.randint(0, 1000000, (1, )).item() | ||||||||||||
|
||||||||||||
if use_negative_prompt and negative_prompt: | ||||||||||||
params.negative_prompt = negative_prompt | ||||||||||||
else: | ||||||||||||
params.negative_prompt = default_params.negative_prompt | ||||||||||||
|
||||||||||||
generator.generate_video(prompt=prompt, sampling_param=params,output_path=parsed_args.output_path,save_video=True) | ||||||||||||
|
||||||||||||
output_path = os.path.join(parsed_args.output_path, | ||||||||||||
f"{params.prompt[:100]}.mp4") | ||||||||||||
|
output_path = os.path.join(parsed_args.output_path, | |
f"{params.prompt[:100]}.mp4") | |
safe_prompt = "".join(c for c in params.prompt[:100] if c.isalnum() or c in (' ', '_')).rstrip().replace(' ', '_') | |
output_path = os.path.join(parsed_args.output_path, | |
f"{safe_prompt}.mp4") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
prompt
,output_path
, andsave_video
arguments are passed togenerator.generate_video
even though they are also part of thesampling_param
object. This is redundant and can be confusing. It would be cleaner to set all parameters in theparams
object and pass only that to the generator.