This repository was archived by the owner on Sep 23, 2025. It is now read-only.

Commit a328494: update
Signed-off-by: minmingzhu <minming.zhu@intel.com>
1 parent: 0ec9205

2 files changed (+35, -33 lines)

llm_on_ray/finetune/finetune_config.py

Lines changed: 2 additions & 1 deletion
@@ -166,6 +166,7 @@ class FinetuneConfig(BaseModel):
166  166        Dataset: Dataset
167  167        Training: Training
168  168
     169  +
169  170    base_models: Dict[str, FinetuneConfig] = {}
170  171    _models: Dict[str, FinetuneConfig] = {}
171  172

@@ -177,6 +178,6 @@ class FinetuneConfig(BaseModel):
177  178                continue
178  179            with open(file_path, "r") as f:
179  180                m: FinetuneConfig = parse_yaml_raw_as(FinetuneConfig, f)
180       -            _models[m.name] = m
     181  +            _models[m.General.base_model] = m
181  182
182  183    all_models = _models.copy()
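
The substantive change here re-keys the module-level _models registry: each config parsed from YAML is now stored under General.base_model rather than a top-level name attribute. A minimal sketch of the new keying, using a stripped-down stand-in for the real FinetuneConfig (the General/base_model structure comes from the diff above; the pydantic modelling details and the example model id are assumptions for illustration):

    # Stripped-down stand-in for the real config classes; only the
    # General.base_model keying mirrors the diff, everything else is illustrative.
    from typing import Dict

    from pydantic import BaseModel


    class General(BaseModel):
        base_model: str


    class FinetuneConfig(BaseModel):
        General: General


    _models: Dict[str, FinetuneConfig] = {}

    cfg = FinetuneConfig(General=General(base_model="example-org/example-7b"))

    # Before this commit the entry was keyed by a top-level name attribute;
    # after it, the base model id is the lookup key.
    _models[cfg.General.base_model] = cfg
    assert "example-org/example-7b" in _models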

llm_on_ray/ui/start_ui.py

Lines changed: 33 additions & 32 deletions
@@ -110,20 +110,20 @@ def get_result(self):
110  110
111  111    class ChatBotUI:
112  112        def __init__(
113       -        self,
114       -        all_models: Dict[str, InferenceConfig],
115       -        base_models: Dict[str, FinetuneConfig],
116       -        finetune_model_path: str,
117       -        finetuned_checkpoint_path: str,
118       -        repo_code_path: str,
119       -        default_data_path: str,
120       -        default_rag_path: str,
121       -        config: dict,
122       -        head_node_ip: str,
123       -        node_port: str,
124       -        node_user_name: str,
125       -        conda_env_name: str,
126       -        master_ip_port: str,
     113  +        self,
     114  +        all_models: Dict[str, InferenceConfig],
     115  +        base_models: Dict[str, FinetuneConfig],
     116  +        finetune_model_path: str,
     117  +        finetuned_checkpoint_path: str,
     118  +        repo_code_path: str,
     119  +        default_data_path: str,
     120  +        default_rag_path: str,
     121  +        config: dict,
     122  +        head_node_ip: str,
     123  +        node_port: str,
     124  +        node_user_name: str,
     125  +        conda_env_name: str,
     126  +        master_ip_port: str,
127  127        ):
128  128            self._all_models = all_models
129  129            self._base_models = base_models

@@ -556,14 +556,15 @@ def finetune(
556  556            finetune_config = self._base_models[model_name]
557  557            gpt_base_model = finetune_config.General.gpt_base_model
558  558
559       -
560  559            finetune_config = finetune_config.dict()
561  560            last_gpt_base_model = False
562  561            finetuned_model_path = os.path.join(self.finetuned_model_path, model_name, new_model_name)
563  562
564  563            exist_worker = int(finetune_config["Training"].get("num_training_workers"))
565  564
566       -        exist_cpus_per_worker_ftn = int(finetune_config["Training"].get("resources_per_worker")["CPU"])
     565  +        exist_cpus_per_worker_ftn = int(
     566  +            finetune_config["Training"].get("resources_per_worker")["CPU"]
     567  +        )
567  568
568  569            ray_resources = ray.available_resources()
569  570            if "CPU" not in ray_resources or cpus_per_worker_ftn * worker_num + 1 > int(
@@ -602,9 +603,9 @@ def finetune(
602  603
603  604            finetune_config["Dataset"]["train_file"] = dataset
604  605            if origin_model_path is not None:
605       -            finetune_config["General"]["base_model"] = origin_model_path
     606  +            finetune_config["General"]["base_model"] = origin_model_path
606  607            if tokenizer_path is not None:
607       -            finetune_config["General"]["tokenizer_name"] = tokenizer_path
     608  +            finetune_config["General"]["tokenizer_name"] = tokenizer_path
608  609            finetune_config["Training"]["epochs"] = num_epochs
609  610            finetune_config["General"]["output_dir"] = finetuned_model_path
610  611

@@ -698,30 +699,30 @@ def finetune_progress(self, progress=gr.Progress()):
698  699                    progress(
699  700                        float(int(value_step) / int(total_steps)),
700  701                        desc="Start Training: epoch "
701       -                    + str(value_epoch)
702       -                    + " / "
703       -                    + str(total_epochs)
704       -                    + " "
705       -                    + "step "
706       -                    + str(value_step)
707       -                    + " / "
708       -                    + str(total_steps),
     702  +                    + str(value_epoch)
     703  +                    + " / "
     704  +                    + str(total_epochs)
     705  +                    + " "
     706  +                    + "step "
     707  +                    + str(value_step)
     708  +                    + " / "
     709  +                    + str(total_steps),
709  710                    )
710  711                except Exception:
711  712                    pass
712  713            self.finetune_status = False
713  714            return "<h4 style='text-align: left; margin-bottom: 1rem'>Completed the fine-tuning process.</h4>"
714  715
715  716        def deploy_func(
716       -        self,
717       -        model_name: str,
718       -        replica_num: int,
719       -        cpus_per_worker_deploy: int,
720       -        hpus_per_worker_deploy: int,
     717  +        self,
     718  +        model_name: str,
     719  +        replica_num: int,
     720  +        cpus_per_worker_deploy: int,
     721  +        hpus_per_worker_deploy: int,
721  722        ):
722  723            self.shutdown_deploy()
723  724            if cpus_per_worker_deploy * replica_num > int(
724       -            ray.available_resources()["CPU"]
     725  +            ray.available_resources()["CPU"]
725  726            ) or hpus_per_worker_deploy * replica_num > int(
726  727                ray.available_resources()["HPU"] if "HPU" in ray.available_resources() else 0
727  728            ):
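
The condition closing this hunk gates deployment on both CPU and HPU headroom, treating a cluster that exposes no HPU resource as having zero. A minimal sketch of the same capacity test, again built only on ray.available_resources(); the helper name and any example values are assumptions:

    # Deploy-time capacity check: both the CPU pool and the (possibly absent)
    # HPU pool must cover every replica. Only ray.available_resources() is the
    # real Ray API; the helper name is illustrative.
    import ray


    def has_deploy_capacity(cpus_per_replica: int, hpus_per_replica: int, replica_num: int) -> bool:
        resources = ray.available_resources()
        free_cpus = int(resources.get("CPU", 0))
        free_hpus = int(resources.get("HPU", 0))  # 0 on clusters without HPU nodes
        return (
            cpus_per_replica * replica_num <= free_cpus
            and hpus_per_replica * replica_num <= free_hpus
        )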
