diff --git a/fastvideo/data_preprocess/preprocess.py b/fastvideo/data_preprocess/preprocess.py
new file mode 100644
index 000000000..b0dd70222
--- /dev/null
+++ b/fastvideo/data_preprocess/preprocess.py
@@ -0,0 +1,131 @@
+"""Command-line entry point for T2V data preprocessing.
+
+Meant to be launched with ``torchrun``: every process reads its rank from the
+environment, builds a :class:`PreprocessPipeline` and runs it with the parsed
+command-line arguments.
+"""
+import argparse
+import json
+import os
+
+import torch
+import torch.distributed as dist
+
+from fastvideo.v1.logger import init_logger
+from fastvideo.v1.utils import maybe_download_model, shallow_asdict
+from fastvideo.v1.distributed import init_distributed_environment, initialize_model_parallel
+from fastvideo.v1.fastvideo_args import FastVideoArgs
+from fastvideo.v1.configs.models.vaes import WanVAEConfig
+from fastvideo import PipelineConfig
+from fastvideo.v1.pipelines.preprocess_pipeline import PreprocessPipeline
+
+logger = init_logger(__name__)
+
+
+def main(args):
+    """Set up the distributed environment and run the preprocess pipeline.
+
+    Args:
+        args: Parsed command-line arguments (see the parser below).
+    """
+    args.model_path = maybe_download_model(args.model_path)
+    # Assume using torchrun. RANK is the global rank; LOCAL_RANK selects the
+    # GPU on this node. Fall back to RANK for launchers that only set RANK.
+    rank = int(os.environ.get("RANK", 0))
+    local_rank = int(os.environ.get("LOCAL_RANK", rank))
+    world_size = int(os.environ.get("WORLD_SIZE", 1))
+    init_distributed_environment(world_size=world_size, rank=rank, local_rank=local_rank)
+    initialize_model_parallel(tensor_model_parallel_size=world_size, sequence_model_parallel_size=world_size)
+    torch.cuda.set_device(local_rank)
+    if not dist.is_initialized():
+        dist.init_process_group(backend="nccl", init_method="env://", world_size=world_size, rank=rank)
+
+    pipeline_config = PipelineConfig.from_pretrained(args.model_path)
+    kwargs = {
+        "use_cpu_offload": False,
+        "vae_precision": "fp32",
+        "vae_config": WanVAEConfig(load_encoder=True, load_decoder=False),
+    }
+    pipeline_config_args = shallow_asdict(pipeline_config)
+    pipeline_config_args.update(kwargs)
+    fastvideo_args = FastVideoArgs(model_path=args.model_path,
+                                   num_gpus=world_size,
+                                   device_str="cuda",
+                                   **pipeline_config_args,
+                                   )
+    fastvideo_args.check_fastvideo_args()
+    fastvideo_args.device = torch.device(f"cuda:{local_rank}")
+
+    pipeline = PreprocessPipeline(args.model_path, fastvideo_args)
+    pipeline.forward(batch=None, fastvideo_args=fastvideo_args, args=args)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    # dataset & dataloader
+    parser.add_argument("--model_path", type=str, default="data/mochi")
+    parser.add_argument("--model_type", type=str, default="mochi")
+    parser.add_argument("--data_merge_path", type=str, required=True)
+    parser.add_argument("--validation_prompt_txt", type=str)
+    parser.add_argument("--num_frames", type=int, default=163)
+    parser.add_argument(
+        "--dataloader_num_workers",
+        type=int,
+        default=1,
+        help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
+    )
+    parser.add_argument(
+        "--preprocess_video_batch_size",
+        type=int,
+        default=2,
+        help="Batch size (per device) for video encoding.",
+    )
+    parser.add_argument(
+        "--preprocess_text_batch_size",
+        type=int,
+        default=8,
+        help="Batch size (per device) for text encoding.",
+    )
+    parser.add_argument(
+        "--samples_per_file",
+        type=int,
+        default=64
+    )
+    parser.add_argument(
+        "--flush_frequency",
+        type=int,
+        default=256,
+        help="how often to save to parquet files"
+    )
+    parser.add_argument("--num_latent_t", type=int, default=28, help="Number of latent timesteps.")
+    parser.add_argument("--max_height", type=int, default=480)
+    parser.add_argument("--max_width", type=int, default=848)
+    parser.add_argument("--video_length_tolerance_range", type=float, default=2.0)
+    parser.add_argument("--group_frame", action="store_true")  # TODO
+    parser.add_argument("--group_resolution", action="store_true")  # TODO
+    parser.add_argument("--dataset", default="t2v")
+    parser.add_argument("--train_fps", type=int, default=30)
+    parser.add_argument("--use_image_num", type=int, default=0)
+    parser.add_argument("--text_max_length", type=int, default=256)
+    parser.add_argument("--speed_factor", type=float, default=1.0)
+    parser.add_argument("--drop_short_ratio", type=float, default=1.0)
+    # text encoder & vae & diffusion model
+    parser.add_argument("--text_encoder_name", type=str, default="google/t5-v1_1-xxl")
+    parser.add_argument("--cache_dir", type=str, default="./cache_dir")
+    parser.add_argument("--cfg", type=float, default=0.0)
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        required=True,
+        help="The output directory where the preprocessed parquet datasets will be written.",
+    )
+    parser.add_argument(
+        "--logging_dir",
+        type=str,
+        default="logs",
+        help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+              " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."),
+    )
+
+    args = parser.parse_args()
+    main(args)
diff --git a/fastvideo/v1/dataset/__init__.py b/fastvideo/v1/dataset/__init__.py
index 19187237a..45dd65110 100644
--- a/fastvideo/v1/dataset/__init__.py
+++ b/fastvideo/v1/dataset/__init__.py
@@ -1,3 +1,5 @@
+import os
+
 from torchvision import transforms
 from torchvision.transforms import Lambda
 from transformers import AutoTokenizer
@@ -25,8 +27,8 @@ def getdataset(args, start_idx=0) -> T2V_dataset:
         *resize_topcrop,
         norm_fun,
     ])
-    # tokenizer = AutoTokenizer.from_pretrained("/storage/ongoing/new/Open-Sora-Plan/cache_dir/mt5-xxl", cache_dir=args.cache_dir)
-    tokenizer = AutoTokenizer.from_pretrained(args.text_encoder_name,
+    tokenizer_path = os.path.join(args.model_path, "tokenizer")
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path,
                                               cache_dir=args.cache_dir)
     if args.dataset == "t2v":
         return T2V_dataset(args,
diff --git a/fastvideo/v1/pipelines/preprocess_pipeline.py b/fastvideo/v1/pipelines/preprocess_pipeline.py
new file mode 100644
index 000000000..77ac788d8
--- /dev/null
+++ b/fastvideo/v1/pipelines/preprocess_pipeline.py
@@ -0,0 +1,509 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+T2V Data Preprocessing pipeline implementation.
+
+This module contains an implementation of the T2V Data Preprocessing pipeline
+using the modular pipeline architecture.
+"""
+import gc
+import multiprocessing
+import os
+from concurrent.futures import ProcessPoolExecutor
+from typing import Any, Dict, List, Optional
+
+import numpy as np
+import pyarrow as pa
+import pyarrow.parquet as pq
+import torch
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+from tqdm import tqdm
+
+from fastvideo.v1.dataset import getdataset
+from fastvideo.v1.dataset.dataloader.schema import pyarrow_schema
+from fastvideo.v1.fastvideo_args import FastVideoArgs
+from fastvideo.v1.logger import init_logger
+from fastvideo.v1.pipelines.composed_pipeline_base import ComposedPipelineBase
+from fastvideo.v1.pipelines.pipeline_batch_info import ForwardBatch
+from fastvideo.v1.pipelines.stages import TextEncodingStage
+
+# TODO(will): move PRECISION_TO_TYPE to better place
+
+logger = init_logger(__name__)
+
+
+class PreprocessPipeline(ComposedPipelineBase):
+    """Encode videos and captions and store the results as parquet datasets.
+
+    Videos are encoded with the ``vae`` module and captions with the
+    ``prompt_encoding_stage``; the encoded samples are written as
+    zstd-compressed parquet files below ``args.output_dir``.
+    """
+
+    _required_config_modules = ["text_encoder", "tokenizer", "vae"]
+
+    # (record key, explicit Arrow type or None to let pyarrow infer it), in the
+    # positional order in which columns are paired with ``pyarrow_schema`` names.
+    _RECORD_COLUMNS = (
+        ("id", None),
+        ("vae_latent_bytes", pa.binary()),
+        ("vae_latent_shape", pa.list_(pa.int32())),
+        ("vae_latent_dtype", None),
+        ("text_embedding_bytes", pa.binary()),
+        ("text_embedding_shape", pa.list_(pa.int32())),
+        ("text_embedding_dtype", None),
+        ("text_attention_mask_bytes", pa.binary()),
+        ("text_attention_mask_shape", pa.list_(pa.int32())),
+        ("text_attention_mask_dtype", None),
+        ("file_name", None),
+        ("caption", None),
+        ("media_type", None),
+        ("width", pa.int32()),
+        ("height", pa.int32()),
+        ("num_frames", pa.int32()),
+        ("duration_sec", pa.float32()),
+        ("fps", pa.float32()),
+    )
+
+    def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
+        """Set up pipeline stages with proper dependency injection."""
+
+        self.add_stage(stage_name="prompt_encoding_stage",
+                       stage=TextEncodingStage(
+                           text_encoders=[self.get_module("text_encoder")],
+                           tokenizers=[self.get_module("tokenizer")],
+                       ))
+
+    @torch.no_grad()
+    def forward(
+        self,
+        batch: ForwardBatch,
+        fastvideo_args: FastVideoArgs,
+        args,
+    ):
+        """Preprocess the validation prompts, then the video/caption dataset.
+
+        Args:
+            batch: Not used by this pipeline.
+            fastvideo_args: Runtime configuration; ``device`` is read from it.
+            args: Parsed command-line arguments of the preprocess script.
+        """
+        # Initialize class variables for data sharing
+        self.video_data: Dict[str, Any] = {}  # Store video metadata and paths
+        self.latent_data: Dict[str, Any] = {}  # Store latent tensors
+        # --validation_prompt_txt is optional; do not crash when it is omitted.
+        if args.validation_prompt_txt:
+            self.preprocess_validation_text(fastvideo_args, args)
+        else:
+            logger.info("No validation prompt file given; skipping it")
+        self.preprocess_video_and_text(fastvideo_args, args)
+
+    def preprocess_video_and_text(self, fastvideo_args: FastVideoArgs, args):
+        """Encode all video/caption pairs and write them to parquet files.
+
+        Encoded samples are buffered in ``self.all_tables`` and written to
+        ``combined_parquet_dataset`` below ``args.output_dir`` whenever at
+        least ``args.flush_frequency`` samples are buffered, and once more
+        after the last batch so the tail of the dataset is not lost.
+        """
+        os.makedirs(args.output_dir, exist_ok=True)
+        # Create directory for combined data
+        combined_parquet_dir = os.path.join(args.output_dir,
+                                            "combined_parquet_dataset")
+        os.makedirs(combined_parquet_dir, exist_ok=True)
+        # The sampler shards the dataset by global rank.
+        rank = int(os.getenv("RANK", 0))
+        world_size = int(os.getenv("WORLD_SIZE", 1))
+
+        # Get how many samples have already been processed; the parquet footer
+        # holds the row count, so the data itself does not have to be loaded.
+        start_idx = 0
+        for root, _, files in os.walk(combined_parquet_dir):
+            for file in files:
+                if file.endswith('.parquet'):
+                    metadata = pq.read_metadata(os.path.join(root, file))
+                    start_idx += metadata.num_rows
+
+        # Loading dataset
+        train_dataset = getdataset(args, start_idx=start_idx)
+        sampler = DistributedSampler(train_dataset,
+                                     rank=rank,
+                                     num_replicas=world_size,
+                                     shuffle=False)
+        train_dataloader = DataLoader(
+            train_dataset,
+            sampler=sampler,
+            batch_size=args.preprocess_video_batch_size,
+            num_workers=args.dataloader_num_workers,
+        )
+
+        # Tables collected since the last flush
+        self.all_tables: List[pa.Table] = []
+        num_buffered_samples = 0
+        # Add progress bar for video preprocessing
+        pbar = tqdm(train_dataloader,
+                    desc="Processing videos",
+                    unit="batch",
+                    disable=rank != 0)
+        for data in pbar:
+            if data is None:
+                continue
+
+            table = self._encode_batch(data, fastvideo_args)
+            if table is None:
+                continue
+            self.all_tables.append(table)
+            num_buffered_samples += len(table)
+            logger.info("Collected batch with %s samples", len(table))
+
+            if num_buffered_samples >= args.flush_frequency:
+                num_buffered_samples = self._flush_tables(
+                    combined_parquet_dir, args.samples_per_file)
+
+        # Flush what is left after the last batch; otherwise up to
+        # ``flush_frequency - 1`` encoded samples would never reach disk.
+        if self.all_tables:
+            self._flush_tables(combined_parquet_dir,
+                               args.samples_per_file,
+                               final=True)
+
+    def _encode_batch(self, data: Dict[str, Any],
+                      fastvideo_args: FastVideoArgs) -> Optional[pa.Table]:
+        """Encode one dataloader batch into a table with one row per sample.
+
+        Returns ``None`` when the batch holds no valid sample.
+        """
+        # Filter out invalid samples (those with all zeros)
+        valid_indices = [
+            i for i, pixel_values in enumerate(data["pixel_values"])
+            if not torch.all(pixel_values == 0)
+        ]
+        if not valid_indices:
+            return None
+
+        # Keep only the valid samples
+        pixel_values = torch.stack(
+            [data["pixel_values"][i] for i in valid_indices])
+        captions = [data["text"][i] for i in valid_indices]
+        paths = [data["path"][i] for i in valid_indices]
+        fps = [data["fps"][i] for i in valid_indices]
+        durations = [data["duration"][i] for i in valid_indices]
+
+        with torch.inference_mode():
+            # VAE
+            with torch.autocast("cuda", dtype=torch.float32):
+                latents = self.get_module("vae").encode(
+                    pixel_values.to(fastvideo_args.device)).mean
+
+            # Text encoder
+            batch = ForwardBatch(
+                data_type="video",
+                prompt=captions,
+                prompt_embeds=[],
+                prompt_attention_mask=[],
+            )
+            assert hasattr(self, "prompt_encoding_stage")
+            result_batch = self.prompt_encoding_stage(batch, fastvideo_args)
+
+        prompt_embeds = result_batch.prompt_embeds[0]
+        prompt_attention_mask = result_batch.prompt_attention_mask[0]
+        assert prompt_embeds.shape[0] == prompt_attention_mask.shape[0]
+
+        # Get sequence lengths from attention masks (number of 1s)
+        seq_lens = prompt_attention_mask.sum(dim=1).tolist()
+
+        records = []
+        for idx, video_path in enumerate(paths):
+            # Keep only the non-padding part of the text tensors
+            # NOTE(review): assumes padding is on the right - confirm.
+            seq_len = seq_lens[idx]
+            vae_latent = latents[idx].cpu().numpy()
+            text_embedding = prompt_embeds[idx, :seq_len].cpu().numpy()
+            text_attention_mask = prompt_attention_mask[
+                idx, :seq_len].cpu().numpy().astype(np.uint8)
+
+            # splitext keeps dots that are part of the name ("a.b.mp4" -> "a.b")
+            video_name = os.path.splitext(os.path.basename(video_path))[0]
+            height, width = pixel_values[idx].shape[-2:]
+
+            records.append({
+                "id": video_name,
+                **self._array_fields("vae_latent", vae_latent),
+                **self._array_fields("text_embedding", text_embedding),
+                **self._array_fields("text_attention_mask",
+                                     text_attention_mask),
+                "file_name": video_name,
+                "caption": captions[idx],
+                "media_type": "video",
+                "width": width,
+                "height": height,
+                "num_frames": latents[idx].shape[1],
+                "duration_sec": float(durations[idx]),
+                "fps": float(fps[idx]),
+            })
+
+        return self._records_to_table(records)
+
+    @staticmethod
+    def _array_fields(prefix: str, array: np.ndarray) -> Dict[str, Any]:
+        """Return the ``{prefix}_bytes/_shape/_dtype`` record fields of *array*."""
+        return {
+            f"{prefix}_bytes": array.tobytes(),
+            f"{prefix}_shape": list(array.shape),
+            f"{prefix}_dtype": str(array.dtype),
+        }
+
+    @classmethod
+    def _records_to_table(cls, records: List[Dict[str, Any]]) -> pa.Table:
+        """Build a table whose columns are named after ``pyarrow_schema``."""
+        arrays = [
+            pa.array([record[key] for record in records], type=arrow_type)
+            for key, arrow_type in cls._RECORD_COLUMNS
+        ]
+        return pa.Table.from_arrays(arrays,
+                                    names=[f.name for f in pyarrow_schema])
+
+    def _flush_tables(self,
+                      output_dir: str,
+                      samples_per_file: int,
+                      final: bool = False) -> int:
+        """Write the buffered tables to parquet files below *output_dir*.
+
+        Files hold ``samples_per_file`` rows. Unless *final* is set, rows that
+        do not fill a complete file stay in ``self.all_tables`` for the next
+        flush; with *final* they are written as one last, shorter file.
+
+        Returns:
+            Number of rows that remain buffered.
+        """
+        combined_table = pa.concat_tables(self.all_tables)
+        num_samples = len(combined_table)
+        if final:
+            total_chunks = -(-num_samples // samples_per_file)  # ceil division
+        else:
+            # Fewer than ``samples_per_file`` rows still yield one short file.
+            total_chunks = max(num_samples // samples_per_file, 1)
+        num_to_write = min(total_chunks * samples_per_file, num_samples)
+        logger.info("Writing %s of %s buffered samples to %s parquet files",
+                    num_to_write, num_samples, total_chunks)
+
+        total_written = self._write_chunks(combined_table, output_dir,
+                                           samples_per_file, total_chunks)
+        logger.info("Total samples written: %s", total_written)
+
+        leftover = combined_table.slice(num_to_write)
+        self.all_tables = [leftover] if len(leftover) > 0 else []
+        return len(leftover)
+
+    def _write_chunks(self, table: pa.Table, output_dir: str,
+                      samples_per_file: int, total_chunks: int) -> int:
+        """Write the first *total_chunks* chunks of *table* in worker processes.
+
+        Ranges that fail in a worker process are retried once in this process;
+        failures are logged, not raised.
+
+        Returns:
+            Number of rows written.
+        """
+        if total_chunks <= 0:
+            return 0
+
+        # Split work among processes
+        num_workers = int(min(multiprocessing.cpu_count(), total_chunks))
+        chunks_per_worker = (total_chunks + num_workers - 1) // num_workers
+        logger.info("Using %s workers to process %s chunks (%s per worker)",
+                    num_workers, total_chunks, chunks_per_worker)
+
+        # Prepare work ranges
+        work_ranges = []
+        for worker_id in range(num_workers):
+            first_chunk = worker_id * chunks_per_worker
+            last_chunk = min((worker_id + 1) * chunks_per_worker, total_chunks)
+            if first_chunk < total_chunks:
+                work_ranges.append((first_chunk, last_chunk, table, worker_id,
+                                    output_dir, samples_per_file))
+
+        total_written = 0
+        failed_ranges = []
+        with ProcessPoolExecutor(max_workers=num_workers) as executor:
+            futures = {
+                executor.submit(self.process_chunk_range, work_range):
+                work_range
+                for work_range in work_ranges
+            }
+            for future in tqdm(futures, desc="Processing chunks"):
+                try:
+                    written = future.result()
+                except Exception as e:
+                    work_range = futures[future]
+                    failed_ranges.append(work_range)
+                    logger.error("Failed to process range %s-%s: %s",
+                                 work_range[0], work_range[1], str(e))
+                else:
+                    total_written += written
+                    logger.info("Processed chunk with %s samples", written)
+
+        # Retry failed ranges sequentially
+        if failed_ranges:
+            logger.warning("Retrying %s failed ranges sequentially",
+                           len(failed_ranges))
+            for work_range in failed_ranges:
+                try:
+                    total_written += self.process_chunk_range(work_range)
+                except Exception as e:
+                    logger.error(
+                        "Failed to process range %s-%s after retry: %s",
+                        work_range[0], work_range[1], str(e))
+
+        return total_written
+
+    def preprocess_validation_text(self, fastvideo_args: FastVideoArgs, args):
+        """Encode the validation prompts and store them as a parquet dataset.
+
+        Reads one prompt per line from ``args.validation_prompt_txt`` and writes
+        a single parquet file below ``validation_parquet_dataset`` in
+        ``args.output_dir``. Video-related fields hold empty placeholders.
+        """
+        # Create Parquet dataset directory for validation
+        validation_parquet_dir = os.path.join(args.output_dir,
+                                              "validation_parquet_dataset")
+        os.makedirs(validation_parquet_dir, exist_ok=True)
+
+        with open(args.validation_prompt_txt, encoding="utf-8") as file:
+            # Blank lines carry no prompt and are skipped
+            prompts = [line.strip() for line in file if line.strip()]
+
+        records = []
+        # Add progress bar for validation text preprocessing
+        pbar = tqdm(prompts,
+                    desc="Processing validation prompts",
+                    unit="prompt")
+        for prompt in pbar:
+            with torch.inference_mode():
+                # Text Encoder
+                batch = ForwardBatch(
+                    data_type="video",
+                    prompt=prompt,
+                    prompt_embeds=[],
+                    prompt_attention_mask=[],
+                )
+                assert hasattr(self, "prompt_encoding_stage")
+                result_batch = self.prompt_encoding_stage(batch, fastvideo_args)
+
+            prompt_embeds = result_batch.prompt_embeds[0]
+            prompt_attention_mask = result_batch.prompt_attention_mask[0]
+
+            # The text before the first "." names the sample
+            file_name = prompt.split(".")[0]
+
+            # Get the sequence length from attention mask (number of 1s)
+            seq_len = prompt_attention_mask.sum().item()
+
+            text_embedding = prompt_embeds[0, :seq_len].cpu().numpy()
+            text_attention_mask = prompt_attention_mask[
+                0, :seq_len].cpu().numpy().astype(np.uint8)
+
+            # Log the shapes after removing padding
+            logger.info(
+                "Shape after removing padding - Embeddings: %s, Mask: %s",
+                text_embedding.shape, text_attention_mask.shape)
+
+            records.append({
+                "id": file_name,
+                "vae_latent_bytes": b"",  # Not available for validation
+                "vae_latent_shape": [],
+                "vae_latent_dtype": "",
+                **self._array_fields("text_embedding", text_embedding),
+                **self._array_fields("text_attention_mask",
+                                     text_attention_mask),
+                "file_name": file_name,
+                "caption": prompt,
+                "media_type": "video",
+                "width": 0,  # Not available for validation
+                "height": 0,  # Not available for validation
+                "num_frames": 0,  # Not available for validation
+                "duration_sec": 0.0,  # Not available for validation
+                "fps": 0.0,  # Not available for validation
+            })
+            logger.info("Saved validation sample: %s", file_name)
+
+        if not records:
+            logger.warning("No validation prompts found in %s",
+                           args.validation_prompt_txt)
+            return
+
+        table = self._records_to_table(records)
+        logger.info("Total validation samples: %s", len(table))
+
+        # A single file holding every validation sample
+        total_written = self._write_chunks(table, validation_parquet_dir,
+                                           len(table), 1)
+        logger.info("Total validation samples written: %s", total_written)
+
+        # Clear memory
+        del table
+        gc.collect()  # Force garbage collection
+
+    @staticmethod
+    def process_chunk_range(args: Any) -> int:
+        """Write a contiguous range of chunks of a table to parquet files.
+
+        Args:
+            args: Tuple ``(start_idx, end_idx, table, worker_id, output_dir,
+                samples_per_file)``. Chunk ``i`` holds the rows from
+                ``i * samples_per_file`` up to ``(i + 1) * samples_per_file``
+                (clipped to the table length); chunks ``start_idx`` to
+                ``end_idx - 1`` go to ``worker_{worker_id}`` in ``output_dir``.
+
+        Returns:
+            Number of rows written.
+        """
+        start_idx, end_idx, table, worker_id, output_dir, samples_per_file = args
+        try:
+            total_written = 0
+            num_samples = len(table)
+
+            # Create worker-specific subdirectory
+            worker_dir = os.path.join(output_dir, f"worker_{worker_id}")
+            os.makedirs(worker_dir, exist_ok=True)
+
+            # File indices are offset by the number of parquet files that are
+            # already in the directory.
+            # NOTE(review): processes sharing a worker directory can pick the
+            # same index; confirm how multi-GPU runs are meant to be laid out.
+            num_parquets = sum(
+                file.endswith('.parquet')
+                for _, _, files in os.walk(worker_dir) for file in files)
+
+            for i in range(start_idx, end_idx):
+                start_sample = i * samples_per_file
+                end_sample = min((i + 1) * samples_per_file, num_samples)
+                chunk = table.slice(start_sample, end_sample - start_sample)
+
+                # Create chunk file in worker's directory
+                chunk_path = os.path.join(
+                    worker_dir, f"data_chunk_{i + num_parquets}.parquet")
+                temp_path = chunk_path + '.tmp'
+
+                try:
+                    # Write to a temporary file, then move it into place so a
+                    # partially written file never carries the final name.
+                    pq.write_table(chunk, temp_path, compression='zstd')
+                    # os.replace overwrites an existing file on every platform
+                    os.replace(temp_path, chunk_path)
+                except Exception:
+                    # Clean up temporary file if it exists
+                    if os.path.exists(temp_path):
+                        os.remove(temp_path)
+                    raise
+
+                total_written += len(chunk)
+
+            return total_written
+        except Exception as e:
+            logger.error("Error processing chunks %s-%s for worker %s: %s",
+                         start_idx, end_idx, worker_id, str(e))
+            raise
+
+
+EntryClass = PreprocessPipeline
diff --git a/fastvideo/v1/pipelines/stages/text_encoding.py b/fastvideo/v1/pipelines/stages/text_encoding.py
index 0e0af0d51..d3015c689 100644
--- a/fastvideo/v1/pipelines/stages/text_encoding.py
+++ b/fastvideo/v1/pipelines/stages/text_encoding.py
@@ -63,10 +63,15 @@ def forward(
             if fastvideo_args.use_cpu_offload:
                 text_encoder = text_encoder.to(fastvideo_args.device)
 
-            assert isinstance(batch.prompt, str)
-            text = preprocess_func(batch.prompt)
-            text_inputs = tokenizer(text, **encoder_config.tokenizer_kwargs).to(
-                fastvideo_args.device)
+            assert isinstance(batch.prompt, (str, list))
+            if isinstance(batch.prompt, str):
+                batch.prompt = [batch.prompt]
+            texts = []
+            for prompt_str in batch.prompt:
+                texts.append(preprocess_func(prompt_str))
+            text_inputs = tokenizer(texts,
+                                    **encoder_config.tokenizer_kwargs).to(
+                                        fastvideo_args.device)
             input_ids = text_inputs["input_ids"]
             attention_mask = text_inputs["attention_mask"]
             with set_forward_context(current_timestep=0, attn_metadata=None):
@@ -78,6 +83,8 @@ def forward(
 
             prompt_embeds = postprocess_func(outputs)
             batch.prompt_embeds.append(prompt_embeds)
+            if batch.prompt_attention_mask is not None:
+                batch.prompt_attention_mask.append(attention_mask)
 
             if batch.do_classifier_free_guidance:
                 assert isinstance(batch.negative_prompt, str)
@@ -98,6 +105,9 @@ def forward(
 
                 assert batch.negative_prompt_embeds is not None
                 batch.negative_prompt_embeds.append(negative_prompt_embeds)
+                if batch.negative_attention_mask is not None:
+                    batch.negative_attention_mask.append(
+                        negative_attention_mask)
 
             if fastvideo_args.use_cpu_offload:
                 text_encoder.to('cpu')
diff --git a/scripts/preprocess/preprocess_wan_data.sh b/scripts/preprocess/preprocess_wan_data.sh
new file mode 100644
index 000000000..5d03fc936
--- /dev/null
+++ b/scripts/preprocess/preprocess_wan_data.sh
@@ -0,0 +1,24 @@
+#!/bin/bash
+# export WANDB_MODE="offline"
+GPU_NUM=1 # 2,4,8
+MODEL_PATH="Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
+MODEL_TYPE="wan"
+DATA_MERGE_PATH="your/path/to/Mixkit-Src/merge.txt"
+OUTPUT_DIR="your/path"
+VALIDATION_PATH="assets/prompt.txt"
+
+torchrun --nproc_per_node="$GPU_NUM" \
+    fastvideo/data_preprocess/preprocess.py \
+    --model_path "$MODEL_PATH" \
+    --data_merge_path "$DATA_MERGE_PATH" \
+    --preprocess_video_batch_size=4 \
+    --max_height=480 \
+    --max_width=832 \
+    --num_frames=81 \
+    --dataloader_num_workers 1 \
+    --output_dir="$OUTPUT_DIR" \
+    --model_type "$MODEL_TYPE" \
+    --train_fps 16 \
+    --validation_prompt_txt "$VALIDATION_PATH" \
+    --samples_per_file 108 \
+    --flush_frequency 108