Text Segmentation Task (with custom dataset) #248
-
I want to train BiRefNet for a text segmentation task.
-
Yes, your process looks right. You can either train from scratch or fine-tune from pre-trained weights. I'm curious to see what performance you can achieve.
-
I am trying to fine-tune BiRefNet on my own custom dataset on my personal Windows system. The dataset folder is structured like this:
Changes I made in config.py:
self.sys_home_dir = r"C:\Users\user\Desktop\workspace"
self.task = 'Custom'
self.testsets = {'Custom': 'TE-Custom'}
self.training_set = {'Custom': 'TR-Custom'}
self.size = (512, 512)
self.compile = False
self.save_last = 50
self.save_step = 5
Changes I made in train.py:
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:False'
pin_memory = False
I could not run the .sh scripts on Windows, so I executed train.py directly from a terminal and passed the arguments through argparse:
python train.py --ckpt_dir ckpt/SAVED --epochs 256 --dist False --resume C:\Users\user\Desktop\workspace\weights\cv\BiRefNet-general-resolution_512x512-fp16-epoch_216.pth
These were the only modifications I made to the codebase. Could you please help me understand what might be going wrong here? Do I need to set the task to Matting?
-
Hi @sau-arv-gul, how are your results now?
-
Hi ZhengPeng7, hope you’re doing well.
Hi @ZhengPeng7!
Thanks a lot for training the model and sharing the checkpoints. They worked great!
I figured out why I wasn’t getting segmented images from my own training on Windows. In dataset.py (line 64), the replacement
p.replace('/im/', '/gt/')
works on Linux but not on Windows, presumably because Windows paths use backslashes, so the '/im/' → '/gt/' substitution never matches and the label path stays identical to the image path.
So my labels were actually just grayscale versions of the input images themselves, not the real GT masks.
image = path_to_image(self.image_paths[index], size=self.data_size, color_type='rgb'…
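One way to make the image-to-label mapping independent of the path separator is sketched below. This is only an illustration, not the fix adopted in BiRefNet: the helper name `image_to_label_path` and its parameters are hypothetical, and the `im` / `gt` folder names are taken from the replacement quoted above. Unlike `str.replace`, it swaps only the last directory component named `im`, and it treats every backslash as a separator, which would be wrong for a Linux filename that really contains a backslash.

```python
import os


def image_to_label_path(image_path, image_dir="im", label_dir="gt", sep=os.sep):
    """Return image_path with its last `image_dir` folder swapped for `label_dir`.

    Hypothetical helper (not part of BiRefNet). Both '/' and '\\' are
    accepted as separators in the input; `sep` is used for the output.
    """
    # Normalise separators so Windows and POSIX paths split the same way.
    parts = image_path.replace("\\", "/").split("/")
    # Search directory components from the right, skipping the file name.
    for i in range(len(parts) - 2, -1, -1):
        if parts[i] == image_dir:
            parts[i] = label_dir
            return sep.join(parts)
    # Fail loudly instead of silently returning the image path as the label.
    raise ValueError(f"no '{image_dir}' folder found in {image_path!r}")


if __name__ == "__main__":
    win_path = r"C:\data\TR-Custom\im\0001.png"  # example path, assumed layout
    # The plain string replacement is a silent no-op on a backslash path:
    print(win_path.replace("/im/", "/gt/") == win_path)
    # The separator-agnostic version finds the folder either way:
    print(image_to_label_path(win_path, sep="\\"))
```

Raising an error when no `im` folder is found is a deliberate choice here: the failure described above went unnoticed precisely because the unchanged image path was quietly used as the label path.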