From 18e5457b5ae517de026f2ff2c692f7d708f49925 Mon Sep 17 00:00:00 2001
From: Xinyu Ye
Date: Thu, 15 Aug 2024 16:02:45 +0800
Subject: [PATCH 01/13] add finetuning microservice.

Signed-off-by: Xinyu Ye
---
 comps/finetuning/README.md                    |  48 ++
 comps/finetuning/datasets/.gitkeep            |   0
 comps/finetuning/docker/Dockerfile.finetune   |  22 +
 comps/finetuning/finetune_runner.py           |  48 ++
 comps/finetuning/finetuning.py                |  90 ++++
 comps/finetuning/handlers.py                  | 140 +++++
 comps/finetuning/jobs/.gitkeep                |   0
 .../finetuning/llm_on_ray/common/__init__.py  |  18 +
 comps/finetuning/llm_on_ray/common/common.py  |  38 ++
 comps/finetuning/llm_on_ray/common/logging.py |  67 +++
 .../llm_on_ray/common/torch_config.py         |  86 +++
 .../llm_on_ray/finetune/__init__.py           |  15 +
 .../llm_on_ray/finetune/data_process.py       | 220 ++++++++
 .../llm_on_ray/finetune/finetune.py           | 492 ++++++++++++++++++
 .../llm_on_ray/finetune/finetune_config.py    | 167 ++++++
 comps/finetuning/models.py                    |  53 ++
 .../finetuning/models/llama-2-7b-chat-hf.yaml |  39 ++
 comps/finetuning/models/mistral-7b-v0.1.yaml  |  45 ++
 comps/finetuning/requirements.txt             |   3 +
 19 files changed, 1591 insertions(+)
 create mode 100644 comps/finetuning/README.md
 create mode 100644 comps/finetuning/datasets/.gitkeep
 create mode 100644 comps/finetuning/docker/Dockerfile.finetune
 create mode 100644 comps/finetuning/finetune_runner.py
 create mode 100644 comps/finetuning/finetuning.py
 create mode 100644 comps/finetuning/handlers.py
 create mode 100644 comps/finetuning/jobs/.gitkeep
 create mode 100644 comps/finetuning/llm_on_ray/common/__init__.py
 create mode 100644 comps/finetuning/llm_on_ray/common/common.py
 create mode 100644 comps/finetuning/llm_on_ray/common/logging.py
 create mode 100644 comps/finetuning/llm_on_ray/common/torch_config.py
 create mode 100644 comps/finetuning/llm_on_ray/finetune/__init__.py
 create mode 100644 comps/finetuning/llm_on_ray/finetune/data_process.py
 create mode 100644 comps/finetuning/llm_on_ray/finetune/finetune.py
 create mode 100644 comps/finetuning/llm_on_ray/finetune/finetune_config.py
 create mode 100644 comps/finetuning/models.py
 create mode 100644 comps/finetuning/models/llama-2-7b-chat-hf.yaml
 create mode 100644 comps/finetuning/models/mistral-7b-v0.1.yaml
 create mode 100644 comps/finetuning/requirements.txt

diff --git a/comps/finetuning/README.md b/comps/finetuning/README.md
new file mode 100644
index 000000000..3b3daf9ef
--- /dev/null
+++ b/comps/finetuning/README.md
@@ -0,0 +1,48 @@
+# LLM Fine-tuning Microservice
+
+The LLM fine-tuning microservice adapts a base model to a specific task or dataset in order to improve its performance on that task.
+
+# 🚀1. Start Microservice with Python
+
+## 1.1 Install Requirements
+
+```bash
+pip install -r requirements.txt
+```
+
+## 1.2 Start Finetuning Service with Python Script
+
+### 1.2.1 Start Ray Cluster
+
+TBD
+
+### 1.2.2 Start Finetuning Service
+
+```bash
+export RAY_ADDRESS="ray://${ray_head_ip}:10001"
+python finetuning/finetuning.py
+```
+
+# 🚀2. Consume Finetuning Service
+
+## 2.1 Check Service Status
+
+```bash
+curl http://${your_ip}:8000/v1/health_check\
+  -X GET \
+  -H 'Content-Type: application/json'
+```
+
+## 2.2 Create fine-tuning job
+
+Assuming a training file `file-vGxE9KywnSUkEL6dv9qZxKAF.jsonl` is uploaded, the following script launches a finetuning job using `meta-llama/Llama-2-7b-chat-hf` as base model:
+
+```bash
+curl http://${your_ip}:8000/v1/fine_tuning/jobs \
+  -X POST \
+  -H "Content-Type: application/json" \
+  -d '{
+    "training_file": "file-vGxE9KywnSUkEL6dv9qZxKAF.jsonl",
+    "model": "meta-llama/Llama-2-7b-chat-hf"
+  }'
+```
diff --git a/comps/finetuning/datasets/.gitkeep b/comps/finetuning/datasets/.gitkeep
new file mode 100644
index 000000000..e69de29bb
diff --git a/comps/finetuning/docker/Dockerfile.finetune b/comps/finetuning/docker/Dockerfile.finetune
new file mode 100644
index 000000000..30b9c6171
--- /dev/null
+++ b/comps/finetuning/docker/Dockerfile.finetune
@@ -0,0 +1,22 @@
+# Use the same Python version as Ray
+FROM python:3.10.14
+
+WORKDIR /root/opea-finetune
+
+RUN --mount=type=cache,target=/var/cache/apt apt-get update -y \
+    && apt-get install -y vim htop net-tools dnsutils \
+    && apt-get clean \
+    && rm -rf /var/lib/apt/lists/*
+
+# COPY ./install-llm-on-ray.sh /tmp/install-llm-on-ray.sh
+# RUN --mount=type=cache,target=/root/.cache/pip /tmp/install-llm-on-ray.sh
+
+COPY ./ .
+
+RUN --mount=type=cache,target=/root/.cache/pip cd ./llm-on-ray && pip install -v -e .[cpu] --extra-index-url https://download.pytorch.org/whl/cpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/cpu/us/
+
+RUN --mount=type=cache,target=/root/.cache/pip pip install --no-cache-dir --upgrade -r requirements.txt
+
+RUN echo 'source $(python -c "import oneccl_bindings_for_pytorch as torch_ccl; print(torch_ccl.cwd)")/env/setvars.sh' >> ~/.bashrc
+
+CMD ["bash", "-c", "./run.sh"]
\ No newline at end of file
diff --git a/comps/finetuning/finetune_runner.py b/comps/finetuning/finetune_runner.py
new file mode 100644
index 000000000..646341d73
--- /dev/null
+++ b/comps/finetuning/finetune_runner.py
@@ -0,0 +1,48 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import argparse
+import time
+import uuid
+from typing import List
+
+from llm_on_ray.finetune.finetune_config import FinetuneConfig
+from pydantic_yaml import parse_yaml_raw_as
+from ray.train.base_trainer import TrainingFailedError
+from ray.tune.callback import Callback
+from ray.tune.experiment import Trial
+from ray.tune.logger import LoggerCallback
+from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments
+
+
+class FineTuneCallback(TrainerCallback):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        print("FineTuneCallback:", args, state)
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Runner for llm_on_ray-finetune")
+    parser.add_argument("--config_file", type=str, required=True, default=None)
+    args = parser.parse_args()
+    model_config_file = args.config_file
+
+    with open(model_config_file) as f:
+        finetune_config = parse_yaml_raw_as(FinetuneConfig, f).model_dump()
+
+    callback = FineTuneCallback()
+    finetune_config["Training"]["callbacks"] = [callback]
+
+    from llm_on_ray.finetune.finetune import main as llm_on_ray_finetune_main
+
+    llm_on_ray_finetune_main(finetune_config)
+    # try:
+    #     llm_on_ray_finetune_main(finetune_config)
+    # except 
TrainingFailedError as e: + # print(e) + + +if __name__ == "__main__": + main() diff --git a/comps/finetuning/finetuning.py b/comps/finetuning/finetuning.py new file mode 100644 index 000000000..a02833c00 --- /dev/null +++ b/comps/finetuning/finetuning.py @@ -0,0 +1,90 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import uvicorn +from fastapi import BackgroundTasks, Cookie, FastAPI, Form, Header, Response +from handlers import ( + handle_cancel_finetuning_job, + handle_create_finetuning_jobs, + handle_list_finetuning_jobs, + handle_retrieve_finetuning_job, +) +from models import FineTuningJob, FineTuningJobList, FineTuningJobsRequest +from pydantic import BaseModel + +app = FastAPI() + + +@app.post("/v1/fine_tuning/jobs", response_model=FineTuningJob) +def create_finetuning_jobs(request: FineTuningJobsRequest, background_tasks: BackgroundTasks): + return handle_create_finetuning_jobs(request, background_tasks) + # return { + # "object": "fine_tuning.job", + # "id": "ftjob-abc123", + # "model": "davinci-002", + # "created_at": 1692661014, + # "finished_at": 1692661190, + # "fine_tuned_model": "ft:davinci-002:my-org:custom_suffix:7q8mpxmy", + # "organization_id": "org-123", + # "result_files": ["file-abc123"], + # "status": "succeeded", + # "validation_file": None, + # "training_file": "file-abc123", + # "hyperparameters": { + # "n_epochs": 4, + # "batch_size": 1, + # "learning_rate_multiplier": 1.0, + # }, + # "trained_tokens": 5768, + # "integrations": [], + # "seed": 0, + # "estimated_finish": 0, + # } + + +@app.get("/v1/fine_tuning/jobs", response_model=FineTuningJobList) +def list_finetuning_jobs(): + return handle_list_finetuning_jobs() + # return { + # "object": "list", + # "data": [ + # { + # "object": "fine_tuning.job", + # "id": "ftjob-abc123", + # "model": "davinci-002", + # "created_at": 1692661014, + # "finished_at": 1692661190, + # "fine_tuned_model": "ft:davinci-002:my-org:custom_suffix:7q8mpxmy", + # "organization_id": "org-123", + # "result_files": ["file-abc123"], + # "status": "succeeded", + # "training_file": "file-abc123", + # "hyperparameters": { + # "n_epochs": 4, + # "batch_size": 1, + # "learning_rate_multiplier": 1.0, + # }, + # "trained_tokens": 5768, + # "integrations": [], + # "seed": 0, + # "estimated_finish": 0, + # }, + # ], + # "has_more": True, + # } + + +@app.get("/v1/fine_tuning/jobs/{fine_tuning_job_id}", response_model=FineTuningJob) +def retrieve_finetuning_job(fine_tuning_job_id): + job = handle_retrieve_finetuning_job(fine_tuning_job_id) + return job + + +@app.post("/v1/fine_tuning/jobs/{fine_tuning_job_id}/cancel", response_model=FineTuningJob) +def cancel_finetuning_job(fine_tuning_job_id): + job = handle_cancel_finetuning_job(fine_tuning_job_id) + return job + + +if __name__ == "__main__": + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/comps/finetuning/handlers.py b/comps/finetuning/handlers.py new file mode 100644 index 000000000..a874369f9 --- /dev/null +++ b/comps/finetuning/handlers.py @@ -0,0 +1,140 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import os +import random +import time +import uuid +from typing import Any, Dict, List, Set + +from fastapi import BackgroundTasks, HTTPException +from llm_on_ray.finetune.finetune import main +from llm_on_ray.finetune.finetune_config import FinetuneConfig +from models import FineTuningJob, FineTuningJobEvent, FineTuningJobList, FineTuningJobsRequest +from pydantic_yaml import parse_yaml_raw_as, to_yaml_file +from 
ray.job_submission import JobSubmissionClient
+from ray.train.base_trainer import TrainingFailedError
+from ray.tune.logger import LoggerCallback
+
+MODEL_CONFIG_FILE_MAP = {
+    "meta-llama/Llama-2-7b-chat-hf": "./models/llama-2-7b-chat-hf.yaml",
+    "mistralai/Mistral-7B-v0.1": "./models/mistral-7b-v0.1.yaml",
+}
+
+DATASET_BASE_PATH = "datasets"
+
+FineTuningJobID = str
+CHECK_JOB_STATUS_INTERVAL = 5  # Check every 5 seconds
+
+global ray_client
+ray_client: JobSubmissionClient = None
+
+running_finetuning_jobs: Dict[FineTuningJobID, FineTuningJob] = {}
+finetuning_job_to_ray_job: Dict[FineTuningJobID, str] = {}
+
+
+# Add a background task to periodically update job status
+def update_job_status(job_id: FineTuningJobID):
+    while True:
+        job_status = ray_client.get_job_status(finetuning_job_to_ray_job[job_id])
+        status = str(job_status).lower()
+        # Ray status "stopped" is OpenAI status "cancelled"
+        status = "cancelled" if status == "stopped" else status
+        print(f"Status of job {job_id} is '{status}'")
+        running_finetuning_jobs[job_id].status = status
+        if status == "finished" or status == "cancelled" or status == "failed":
+            break
+        time.sleep(CHECK_JOB_STATUS_INTERVAL)
+
+
+def handle_create_finetuning_jobs(request: FineTuningJobsRequest, background_tasks: BackgroundTasks):
+    base_model = request.model
+    train_file = request.training_file
+    train_file_path = os.path.join(DATASET_BASE_PATH, train_file)
+
+    model_config_file = MODEL_CONFIG_FILE_MAP.get(base_model)
+    if not model_config_file:
+        raise HTTPException(status_code=404, detail=f"Base model '{base_model}' not supported!")
+
+    if not os.path.exists(train_file_path):
+        raise HTTPException(status_code=404, detail=f"Training file '{train_file}' not found!")
+
+    with open(model_config_file) as f:
+        finetune_config = parse_yaml_raw_as(FinetuneConfig, f)
+
+    finetune_config.Dataset.train_file = train_file_path
+
+    job = FineTuningJob(
+        id=f"ft-job-{uuid.uuid4()}",
+        model=base_model,
+        created_at=int(time.time()),
+        training_file=train_file,
+        hyperparameters={
+            "n_epochs": finetune_config.Training.epochs,
+            "batch_size": finetune_config.Training.batch_size,
+            "learning_rate_multiplier": finetune_config.Training.learning_rate,
+        },
+        status="running",
+        # TODO: Add seed in finetune config
+        seed=random.randint(0, 1000),
+    )
+
+    finetune_config_file = f"jobs/{job.id}.yaml"
+    to_yaml_file(finetune_config_file, finetune_config)
+
+    global ray_client
+    ray_client = JobSubmissionClient() if ray_client is None else ray_client
+
+    ray_job_id = ray_client.submit_job(
+        # Entrypoint shell command to execute
+        entrypoint=f"python finetune_runner.py --config_file {finetune_config_file}",
+        # Path to the local directory that contains finetune_runner.py
+        runtime_env={"working_dir": "./"},
+    )
+    print(f"Submitted Ray job: {ray_job_id} ...")
+
+    running_finetuning_jobs[job.id] = job
+    finetuning_job_to_ray_job[job.id] = ray_job_id
+
+    background_tasks.add_task(update_job_status, job.id)
+
+    return job
+
+
+def handle_list_finetuning_jobs():
+    finetuning_jobs_list = FineTuningJobList(data=list(running_finetuning_jobs.values()), has_more=False)
+
+    return finetuning_jobs_list
+
+
+def handle_retrieve_finetuning_job(fine_tuning_job_id):
+    job = running_finetuning_jobs.get(fine_tuning_job_id)
+    if job is None:
+        raise HTTPException(status_code=404, detail=f"Fine-tuning job '{fine_tuning_job_id}' not found!")
+    return job
+
+
+def handle_cancel_finetuning_job(fine_tuning_job_id):
+    ray_job_id = finetuning_job_to_ray_job.get(fine_tuning_job_id)
+    if ray_job_id is None:
+        raise HTTPException(status_code=404, detail=f"Fine-tuning job '{fine_tuning_job_id}' not found!")
+
+    global ray_client
+    ray_client = JobSubmissionClient() if ray_client is None else ray_client
+    ray_client.stop_job(ray_job_id)
+
+    job = running_finetuning_jobs.get(fine_tuning_job_id)
+    job.status = "cancelled"
+    return job
+
+
+# def cancel_all_jobs():
+#     global ray_client
+#     ray_client = JobSubmissionClient() if ray_client is None else ray_client
+#     # stop all jobs
+#     for job_id in finetuning_job_to_ray_job.values():
+#         ray_client.stop_job(job_id)

+#     for job_id in running_finetuning_jobs:
+#         running_finetuning_jobs[job_id].status = "cancelled"
+#     return running_finetuning_jobs
diff --git a/comps/finetuning/jobs/.gitkeep b/comps/finetuning/jobs/.gitkeep
new file mode 100644
index 000000000..e69de29bb
diff --git a/comps/finetuning/llm_on_ray/common/__init__.py b/comps/finetuning/llm_on_ray/common/__init__.py
new file mode 100644
index 000000000..a84df4482
--- /dev/null
+++ b/comps/finetuning/llm_on_ray/common/__init__.py
@@ -0,0 +1,18 @@
+#
+# Copyright 2023 The LLM-on-Ray Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from llm_on_ray.common.logging import logger
+from llm_on_ray.common.torch_config import TorchConfig
diff --git a/comps/finetuning/llm_on_ray/common/common.py b/comps/finetuning/llm_on_ray/common/common.py
new file mode 100644
index 000000000..87f74096c
--- /dev/null
+++ b/comps/finetuning/llm_on_ray/common/common.py
@@ -0,0 +1,38 @@
+#
+# Copyright 2023 The LLM-on-Ray Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import glob
+import importlib
+import os
+
+from llm_on_ray.common.logging import logger
+
+
+def import_all_modules(basedir, prefix=None):
+    all_py_files = glob.glob(basedir + "/*.py")
+    modules = [os.path.basename(f) for f in all_py_files]
+
+    for module in modules:
+        if not module.startswith("_"):
+            module = module[: -len(".py")]
+            if prefix is None:
+                module_name = module
+            else:
+                module_name = f"{prefix}.{module}"
+            try:
+                importlib.import_module(module_name)
+            except Exception:
+                logger.warning(f"Failed to import {module_name}", exc_info=True)
diff --git a/comps/finetuning/llm_on_ray/common/logging.py b/comps/finetuning/llm_on_ray/common/logging.py
new file mode 100644
index 000000000..6d3f6ae80
--- /dev/null
+++ b/comps/finetuning/llm_on_ray/common/logging.py
@@ -0,0 +1,67 @@
+#
+# Copyright 2023 The LLM-on-Ray Authors.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import functools +import logging +import logging.config +import traceback + +__all__ = ["logger", "get_logger"] + +use_accelerate_log = False +logger_name = "common" + +logging_config = { + "version": 1, + "loggers": { + "root": {"level": "INFO", "handlers": ["consoleHandler"]}, + "common": { + "level": "INFO", + "handlers": ["consoleHandler"], + "qualname": "common", + "propagate": 0, + }, + }, + "handlers": { + "consoleHandler": { + "class": "logging.StreamHandler", + "level": "INFO", + "formatter": "standardFormatter", + }, + }, + "formatters": { + "standardFormatter": { + "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s", + "datefmt": "", + } + }, +} + +if logging_config is not None: + try: + logging.config.dictConfig(logging_config) + except Exception: + traceback.print_exc() + exit(1) + +if use_accelerate_log: + import accelerate + + get_logger = functools.partial(accelerate.logging.get_logger, name=logger_name) +else: + get_logger = functools.partial(logging.getLogger, name=logger_name) + +logger = get_logger() diff --git a/comps/finetuning/llm_on_ray/common/torch_config.py b/comps/finetuning/llm_on_ray/common/torch_config.py new file mode 100644 index 000000000..115093ba3 --- /dev/null +++ b/comps/finetuning/llm_on_ray/common/torch_config.py @@ -0,0 +1,86 @@ +# +# Copyright 2023 The LLM-on-Ray Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from ray.train.torch.config import _TorchBackend +from ray.train.torch.config import TorchConfig as RayTorchConfig +from ray.train._internal.worker_group import WorkerGroup +from dataclasses import dataclass +from typing import Optional +import os +import sys + +# The package importlib_metadata is in a different place, depending on the Python version. 
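+# On Python 3.8 and newer the standard-library importlib.metadata module is used; older interpreters
+# fall back to the importlib_metadata backport. It is only needed below to read the installed
+# oneccl_bind_pt version in xpu_libs_import().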
+if sys.version_info < (3, 8): + import importlib_metadata +else: + import importlib.metadata as importlib_metadata + + +@dataclass +class TorchConfig(RayTorchConfig): + device: Optional[str] = None + + @property + def backend_cls(self): + EnableCCLBackend.device = self.device + return EnableCCLBackend + + +def xpu_libs_import(): + """try to import IPEX and oneCCL.""" + try: + import intel_extension_for_pytorch + except ImportError: + raise ImportError("Please install intel_extension_for_pytorch") + try: + ccl_version = importlib_metadata.version("oneccl_bind_pt") + if ccl_version >= "1.12": + import oneccl_bindings_for_pytorch + else: + import torch_ccl + except ImportError as ccl_not_exist: + raise ImportError("Please install torch-ccl") from ccl_not_exist + + +def hpu_libs_import(): + """try to import habana frameworkfs for torch""" + try: + import habana_frameworks.torch # noqa: F401 + except ImportError as habana_not_exist: + raise ImportError("Please install habana_frameworks") from habana_not_exist + + +def _set_torch_distributed_env_vars(device): + if device is not None: + os.environ["ACCELERATE_TORCH_DEVICE"] = device + + +class EnableCCLBackend(_TorchBackend): + device: Optional[str] = None + + def on_start(self, worker_group: WorkerGroup, backend_config: RayTorchConfig): + libs_import = ( + hpu_libs_import + if self.device is not None and self.device.startswith("hpu") + else xpu_libs_import + ) + for i in range(len(worker_group)): + worker_group.execute_single_async(i, libs_import) + super().on_start(worker_group, backend_config) + + def on_training_start(self, worker_group: WorkerGroup, backend_config: RayTorchConfig): + super().on_training_start(worker_group, backend_config) + worker_group.execute(_set_torch_distributed_env_vars, self.device) diff --git a/comps/finetuning/llm_on_ray/finetune/__init__.py b/comps/finetuning/llm_on_ray/finetune/__init__.py new file mode 100644 index 000000000..854e39ad4 --- /dev/null +++ b/comps/finetuning/llm_on_ray/finetune/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright 2023 The LLM-on-Ray Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/comps/finetuning/llm_on_ray/finetune/data_process.py b/comps/finetuning/llm_on_ray/finetune/data_process.py new file mode 100644 index 000000000..c617a4215 --- /dev/null +++ b/comps/finetuning/llm_on_ray/finetune/data_process.py @@ -0,0 +1,220 @@ +# +# Copyright 2023 The LLM-on-Ray Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import copy +import re +from itertools import chain + +import torch + +IGNORE_INDEX = -100 + + +class DataProcessor: + # We used the following prompts for fine-tuning the Alpaca model. You can find reference doc form this URL(https://github.com/tatsu-lab/stanford_alpaca/blob/main/README.md#data-release) + def __init__(self, config, tokenizer): + self.tokenizer = tokenizer + self.end = tokenizer.eos_token + self.intro = "Below is an instruction that describes a task. Write a response that appropriately completes the request." + self.instruction = "### Instruction:\n" + self.input = "### Input:\n" + self.response = "### Response:\n" + self.padding_side = config["Dataset"].get("padding_side", "right") + self.truncation_side = config["Dataset"].get("truncation_side", "right") + self.max_length = self.max_seq_length = config["Dataset"].get("max_length", 512) + self.max_source_length = config["Dataset"].get("max_source_length", 384) + self.truncation = config["Dataset"].get("truncation", True) + self.padding = config["Dataset"].get("padding", True) + self.mask_input = config["Dataset"].get("mask_input", True) + self.mask_response = config["Dataset"].get("mask_response", True) + + def make_prompt(self, examples): + prompts = {} + prompts["prompt_sources"] = [] + prompts["prompt_targets"] = [] + for rec in examples: + instruction = rec["instruction"] + response = rec["input"] + context = rec.get("output") + if not instruction: + raise ValueError(f"Expected an instruction in: {rec}") + # if not response: + # raise ValueError(f"Expected a response in: {rec}") + if context: + prompt = ( + self.intro + + self.end + + "\n" + + self.instruction + + instruction + + self.input + + context + + self.end + + "\n" + + self.response + ) + prompts["prompt_sources"].append(prompt) + else: + prompt = ( + self.intro + + self.end + + "\n" + + self.instruction + + instruction + + self.end + + "\n" + + self.response + ) + prompts["prompt_sources"].append(prompt) + prompt_response = response + self.end + prompts["prompt_targets"].append(prompt_response) + return prompts + + def __truncate_sequences(self, sequences, max_length): + """ + Copied from https://github.com/intel/intel-extension-for-transformers/blob/ae54f698b73a66e5729427cb19f69c33e1a5c34d/intel_extension_for_transformers/transformers/llm/finetuning/data_utils.py#L40 + """ + words_to_cut = sum(list(map(len, sequences))) - max_length + if words_to_cut <= 0: + return sequences + + while words_to_cut > 0 and len(sequences) > 0: + words_to_cut -= len(sequences[0]) + sequences = sequences[1:] + return sequences + + def tokenize_by_neural_chat(self, examples): + """ + Copied from https://github.com/intel/intel-extension-for-transformers/blob/ae54f698b73a66e5729427cb19f69c33e1a5c34d/intel_extension_for_transformers/transformers/llm/finetuning/data_utils.py#L225 + The only differences are: + - using our own prompt style + - add left or right padding and truncation + - add mask_input and mask_response + """ + keys = list(examples.data.keys()) + if len(keys) != 2: + raise ValueError("Unsupported dataset format") + assistant_tokens = self.tokenizer.tokenize(self.response) + header = self.intro + self.end + "\n" + + examples["input_ids"] = [] + examples["labels"] = [] + examples["attention_mask"] = [] + for instruction, response in zip(examples[keys[0]], examples[keys[1]]): + convs = re.findall( + r"{0}.*?{2}|{1}.*?{2}".format(self.instruction, self.response, self.end), + instruction, + re.DOTALL, + ) + convs_tokens = [ + self.tokenizer.tokenize(conv) + 
self.tokenizer.tokenize("\n") for conv in convs + ] + header_tokens = self.tokenizer.tokenize(header) + self.tokenizer.tokenize("\n") + max_input = self.max_source_length - len(header_tokens) - len(assistant_tokens) + truncated_convs = self.__truncate_sequences(convs_tokens, max_input) + if len(truncated_convs) == 0: + truncated_convs = [convs_tokens[-1][: max_input - 3] + convs_tokens[-1][-3:]] + + prompt_tokens = [header_tokens] + truncated_convs + [assistant_tokens] + prompt_ids = [ + self.tokenizer.convert_tokens_to_ids(prompt_token) for prompt_token in prompt_tokens + ] + prompt_ids = list(chain(*prompt_ids)) + + resp_ids = self.tokenizer.convert_tokens_to_ids( + self.tokenizer.tokenize(response.strip()) + ) + # keep last and eos_id + max_resp = self.max_seq_length - len(prompt_ids) - 1 + + # truncating response + if len(resp_ids) > max_resp: + if self.truncation_side == "right": + resp_ids = resp_ids[: max_resp - 1] + resp_ids[-1:] + else: + resp_ids = resp_ids[-max_resp:] + + # masking + input_ids = prompt_ids + resp_ids + [self.tokenizer.eos_token_id] + if self.mask_input: + labels = [IGNORE_INDEX] * len(prompt_ids) + resp_ids + [self.tokenizer.eos_token_id] + elif self.mask_response: + labels = prompt_ids + [IGNORE_INDEX] * len(resp_ids) + [self.tokenizer.eos_token_id] + else: + labels = input_ids + + # padding + input_len = len(input_ids) + pad_len = self.max_seq_length - input_len + if self.padding_side == "right": + input_ids = input_ids + [self.tokenizer.eos_token_id] * pad_len + labels = labels + [IGNORE_INDEX] * pad_len + attention_mask = [1] * input_len + [0] * pad_len + else: + input_ids = [self.tokenizer.eos_token_id] * pad_len + input_ids + labels = [IGNORE_INDEX] * pad_len + labels + attention_mask = [0] * pad_len + [1] * input_len + + assert len(input_ids) == self.max_seq_length + assert len(prompt_ids) <= self.max_source_length + assert len(labels) == len(input_ids) == len(attention_mask) + + examples["input_ids"].append(torch.tensor(input_ids)) + examples["labels"].append(labels) + examples["attention_mask"].append(attention_mask) + + return examples + + def tokenize(self, examples): + keys = list(examples.data.keys()) + if len(keys) != 2: + raise ValueError("Unsupported dataset format") + + examples["input_ids"] = [] + examples["labels"] = [] + examples["attention_mask"] = [] + for s, t in zip(examples[keys[0]], examples[keys[1]]): + results = self.tokenizer( + s + t, + padding=self.padding, + truncation=self.truncation, + return_tensors=None, + max_length=self.max_length, + ) + + input_ids = results["input_ids"] + input_len = len(input_ids) + labels = copy.deepcopy(input_ids) + if self.mask_input or self.mask_response: + sources_tokenized = self.tokenizer( + s, + padding=False, + truncation=True, + return_tensors=None, + max_length=self.max_length, + ) + input_id_len = len(sources_tokenized["input_ids"]) + # mask input + if self.mask_input: + labels[:input_id_len] = [IGNORE_INDEX] * input_id_len + # mask response + if self.mask_response: + labels[input_id_len:input_len] = [IGNORE_INDEX] * (input_len - input_id_len) + + examples["input_ids"].append(results["input_ids"]) + examples["labels"].append(labels) + examples["attention_mask"].append(results["attention_mask"]) + return examples diff --git a/comps/finetuning/llm_on_ray/finetune/finetune.py b/comps/finetuning/llm_on_ray/finetune/finetune.py new file mode 100644 index 000000000..46714cdbb --- /dev/null +++ b/comps/finetuning/llm_on_ray/finetune/finetune.py @@ -0,0 +1,492 @@ +# +# Copyright 2023 The LLM-on-Ray 
Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +#!/usr/bin/env python + +import os +import argparse +import re +import sys +import copy + +from typing import Any, Dict, Union, Optional + +from itertools import chain + +import torch +import datasets +import transformers + +from peft import get_peft_model, LoraConfig + +import ray +from ray.train.torch import TorchTrainer +from ray.air.config import ScalingConfig +from ray.air import RunConfig, FailureConfig + +from pydantic_yaml import parse_yaml_raw_as + +from llm_on_ray import common +from llm_on_ray.finetune.data_process import DataProcessor +from llm_on_ray.finetune.finetune_config import FinetuneConfig + + +def adapt_transformers_to_device(config: Dict): + device = config["Training"]["device"] + if device in ["hpu"]: + from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi + + # adapt transformers to gaudi + adapt_transformers_to_gaudi() + + +def set_seed(config: Dict): + seed = config["Training"].get("seed", None) + if seed is None: + return + device = config["Training"]["device"] + if device in ["cpu", "gpu"]: + from accelerate.utils import set_seed as _set_seed + + _set_seed(seed) + elif device in ["hpu"]: + from optimum.habana.utils import set_seed as _set_seed + + _set_seed(seed) + + +def convert_to_training_args(cls, config: Dict): + device = config["Training"]["device"] + accelerate_mode = config["Training"]["accelerate_mode"] + save_strategy = config["General"]["save_strategy"] + + args = { + "output_dir": config["General"]["output_dir"], + "report_to": config["General"]["report_to"], + "resume_from_checkpoint": config["General"]["resume_from_checkpoint"], + "gradient_checkpointing": config["General"]["enable_gradient_checkpointing"], + "save_strategy": save_strategy if save_strategy != "False" else "no", + "bf16": config["Training"]["mixed_precision"] == "bf16", + "num_train_epochs": config["Training"]["epochs"], + "per_device_train_batch_size": config["Training"]["batch_size"], + "per_device_eval_batch_size": config["Training"]["batch_size"], + "optim": config["Training"]["optimizer"], + "learning_rate": config["Training"]["learning_rate"], + "logging_steps": config["Training"]["logging_steps"], + "lr_scheduler_type": config["Training"]["lr_scheduler"], + "weight_decay": config["Training"]["weight_decay"], + "gradient_accumulation_steps": config["Training"]["gradient_accumulation_steps"], + "do_train": True, + } + + # set attr do_eval + vf = config["Dataset"].get("validation_file", None) + vsp = config["Dataset"].get("validation_split_percentage", 0) + if vf is not None or (vsp / 100 > 0.0 and vsp / 100 < 1.0): + args.update({"do_eval": True}) + + # set attr max_steps + if config["Training"]["max_train_steps"] is not None: + args.update({"max_steps": config["Training"]["max_train_steps"]}) + + # set attr for device cpu + if device == "cpu": + if hasattr(cls, "use_cpu"): + args.update({"use_cpu": True}) + if hasattr(cls, "no_cuda"): + args.update({"no_cuda": True}) + 
args.update({"use_ipex": True}) + + # set attr 'deepspeed' + if accelerate_mode == "DEEPSPEED": + args.update({"deepspeed": config["Training"]["deepspeed_config_file"]}) + + # set attr for FSDP + # if accelerate_mode == "FSDP": + # args.updatwe({}) + + # set attr for Intel Gaudi + if device == "hpu": + args.update({"use_habana": True}) + args.update({"use_lazy_mode": config["Training"]["hpu_execution_mode"] == "lazy"}) + args.update({"pipelining_fwd_bwd": True}) + + return cls(**args) + + +def convert_dtype(dtype: str) -> Optional[torch.dtype]: + supported_dtypes = { + "fp16": torch.float16, + "bf16": torch.bfloat16, + "no": None, + } + return supported_dtypes[dtype] + + +def load_tokenizer(config: Dict): + if config["General"].get("tokenizer_name") is not None: + tokenizer_name = config["General"].get("tokenizer_name") + else: + tokenizer_name = config["General"]["base_model"] + load_config = config["General"].get("config", {}) + # default padding side is right + padding_side = config["Dataset"].get("padding_side", "right") + # default truncation side is right + truncation_side = config["Dataset"].get("truncation_side", "right") + tokenizer = transformers.AutoTokenizer.from_pretrained( + tokenizer_name, padding_side=padding_side, truncation_side=truncation_side, **load_config + ) + return tokenizer + + +def load_dataset(config: Dict): + dataset_file = config["Dataset"].get("train_file", None) + if dataset_file is None: + return + + if os.path.exists(dataset_file): + # load from local file + def local_load(name, **load_config): + if os.path.isfile(name): + file = os.path.basename(os.path.abspath(name)) + path = os.path.dirname(os.path.abspath(name)) + dataset = datasets.load_dataset(path, data_files=file, **load_config) + else: + dataset = datasets.load_dataset(name, **load_config) + return dataset["train"] + + train_dataset = local_load(dataset_file) + validation_file = config["Dataset"].get("validation_file", None) + if validation_file is not None: + validation_dataset = local_load(validation_file) + return datasets.DatasetDict({"train": train_dataset, "validation": validation_dataset}) + + validation_split_percentage = config["Dataset"].get("validation_split_percentage", 0) + if validation_split_percentage / 100 > 0.0 and validation_split_percentage / 100 < 1.0: + dataset_dict = train_dataset.train_test_split( + test_size=validation_split_percentage / 100 + ) + dataset_dict["validation"] = dataset_dict["test"] + return dataset_dict + + return datasets.DatasetDict({"train": train_dataset}) + else: + # try to download and load dataset from huggingface.co + load_config = config["General"].get("config", {}) + use_auth_token = load_config.get("use_auth_token", None) + raw_dataset = datasets.load_dataset(dataset_file, use_auth_token=use_auth_token) + + validation_split_percentage = config["Dataset"].get("validation_split_percentage", 0) + if "validation" not in raw_dataset.keys() and ( + validation_split_percentage / 100 > 0.0 and validation_split_percentage / 100 < 1.0 + ): + dataset_dict = raw_dataset["train"].train_test_split( + test_size=validation_split_percentage / 100 + ) + dataset_dict["validation"] = dataset_dict["test"] + return dataset_dict + + return raw_dataset + + +def tokenize_dataset(config: Dict, tokenizer, dataset): + group = config["Dataset"].get("group", True) + block_size = config["Dataset"].get("block_size", 512) + tokenizer.pad_token = tokenizer.eos_token + + processor = DataProcessor(config, tokenizer) + + for key in dataset: + prompts = 
processor.make_prompt(dataset[key]) + dataset[key] = datasets.Dataset.from_dict(prompts) + + column_names = list(dataset["train"].features) + tokenize_fn = ( + processor.tokenize_by_neural_chat + if config["Dataset"].get("data_preprocess_type", "") == "neural_chat" + else processor.tokenize + ) + + tokenized_dataset = dataset.map( + tokenize_fn, + remove_columns=column_names, + batched=True, + load_from_cache_file=False, + desc="Tokenize dataset", + ) + + if group: + + def group_texts(examples): + # Concatenate all texts. + concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} + total_length = len(concatenated_examples[list(examples.keys())[0]]) + # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can + # customize this part to your needs. + if total_length >= block_size: + total_length = (total_length // block_size) * block_size + # Split by chunks of max_len. + result = { + k: [t[i : i + block_size] for i in range(0, total_length, block_size)] + for k, t in concatenated_examples.items() + } + return result + + tokenized_dataset = tokenized_dataset.map( + group_texts, + batched=True, + load_from_cache_file=False, + desc=f"Grouping texts in chunks of {block_size}", + ) + + return tokenized_dataset + + +def prepare_data_collator(config: Dict, tokenizer): + return transformers.DataCollatorForLanguageModeling( + tokenizer=tokenizer, mlm=False, return_tensors="pt", pad_to_multiple_of=8 + ) + + +def load_model(config: Dict): + model_name = config["General"]["base_model"] + model_dtype = convert_dtype(config["Training"].get("mixed_precision", "no")) + model_config = config["General"].get("config", {}) + model = transformers.AutoModelForCausalLM.from_pretrained( + model_name, torch_dtype=model_dtype, **model_config + ) + + lora_config = config["General"].get("lora_config", None) + if lora_config: + peft_config = LoraConfig(**lora_config) + model = get_peft_model(model, peft_config) + + egc = config["General"].get("enable_gradient_checkpointing", False) + if egc: + model.enable_input_require_grads() + model.gradient_checkpointing_enable() + model.config.use_cache = False + + model.to(dtype=model_dtype, device=torch.device(config["Training"]["device"])) + + return model + + +def get_trainer(config: Dict, model, tokenizer, tokenized_dataset, data_collator): + device = config["Training"]["device"] + if device in ["cpu", "gpu"]: + from transformers import Trainer, TrainingArguments + + training_args = convert_to_training_args(TrainingArguments, config) + trainer = Trainer( + model=model, + args=training_args, + train_dataset=tokenized_dataset["train"], + eval_dataset=tokenized_dataset["validation"] + if tokenized_dataset.get("validation") is not None + else None, + tokenizer=tokenizer, + data_collator=data_collator, + ) + return training_args, trainer + elif device in ["hpu"]: + from optimum.habana.transformers import GaudiTrainer + from optimum.habana.transformers import GaudiTrainingArguments + from optimum.habana import GaudiConfig + + # If gaudi_config_name is provided, load gaudi_config from huggingface model hub(https://huggingface.co/Habana), otherwise use default gaudi_config + gaudi_config_name = config["General"].get("gaudi_config_name", None) + if gaudi_config_name is not None: + gaudi_config = GaudiConfig.from_pretrained(gaudi_config_name) + else: + gaudi_config = GaudiConfig() + gaudi_config.use_fused_adam = True + gaudi_config.use_fused_clip_norm = True + + training_args = 
convert_to_training_args(GaudiTrainingArguments, config) + trainer = GaudiTrainer( + model=model, + args=training_args, + gaudi_config=gaudi_config, + train_dataset=tokenized_dataset["train"], + eval_dataset=tokenized_dataset["validation"] + if tokenized_dataset.get("validation") is not None + else None, + tokenizer=tokenizer, + data_collator=data_collator, + ) + return training_args, trainer + return None + + +def train_func(config: Dict[str, Any]): + os.chdir(config["cwd"]) + + adapt_transformers_to_device(config) + + set_seed(config) + + tokenizer = load_tokenizer(config) + + dataset = load_dataset(config) + + max_train_samples = config["Dataset"].get("max_train_samples", 0) + if 0 < max_train_samples < len(dataset["train"]): + dataset["train"] = dataset["train"].select(range(max_train_samples)) + + max_eval_samples = config["Dataset"].get("max_eval_samples", 0) + if "validation" in dataset and 0 < max_eval_samples < len(dataset["validation"]): + dataset["validation"] = dataset["validation"].select(range(max_eval_samples)) + + tokenized_dataset = tokenize_dataset(config, tokenizer, dataset) + + data_collator = prepare_data_collator(config, tokenizer) + + model = load_model(config) + + training_args, trainer = get_trainer(config, model, tokenizer, tokenized_dataset, data_collator) + + common.logger.info("train start") + trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) + trainer.save_model() + common.logger.info("train finish") + + +def get_finetune_config(): + parser = argparse.ArgumentParser( + description="Finetune a transformers model on a causal language modeling task" + ) + parser.add_argument( + "--config_file", + type=str, + required=True, + default=None, + help="The name of the dataset to use (via the datasets library).", + ) + + # Print help if no arguments were provided + if len(sys.argv) == 1: + parser.print_help(sys.stderr) + sys.exit(1) + + args = parser.parse_args() + config_file = args.config_file + + with open(config_file) as f: + finetune_config = parse_yaml_raw_as(FinetuneConfig, f) + return finetune_config.dict() + + +def main(external_config=None): + if not external_config: + config = get_finetune_config() + else: + config = external_config + + config["cwd"] = os.getcwd() + + num_training_workers = config["Training"].get("num_training_workers") + resources_per_worker = config["Training"].get("resources_per_worker") + + if num_training_workers > 1 and config["Training"].get("accelerate_mode", None) is None: + config["Training"][ + "accelerate_mode" + ] = "DDP" # will use DDP to accelerate if no method specified + + ccl_worker_count = 1 + device = config["Training"]["device"] + if device != "cpu": + ccl_worker_count = num_training_workers + + if not ray.is_initialized(): + runtime_env = { + "env_vars": { + "OMP_NUM_THREADS": str(resources_per_worker["CPU"]), + "CCL_ZE_IPC_EXCHANGE": "sockets", + "CCL_WORKER_COUNT": str(ccl_worker_count), + "CCL_LOG_LEVEL": "info", + "FI_TCP_IFACE": "lo", + "FI_PROVIDER": "tcp", + } + } + + if config["General"]["gpt_base_model"] is True: + runtime_env["pip"] = ["transformers==4.26.0"] + + if device == "gpu": + num_cpus = ( + resources_per_worker["CPU"] * num_training_workers + 1 + ) # additional 1 for head worker + ray.init(num_cpus=num_cpus, runtime_env=runtime_env) + else: + ray.init(runtime_env=runtime_env) + + common.logger.info(f"ray available resources = {ray.available_resources()}") + use_gpu = True if device == "gpu" else False + scaling_config = ScalingConfig( + num_workers=num_training_workers, + 
use_gpu=use_gpu, + resources_per_worker=resources_per_worker, + placement_strategy="SPREAD", + ) + + # if try to use Intel GPU, convert device to 'xpu' + # due to accelerate internal use 'xpu' represent Intel GPU + if device == "gpu": + from accelerate.utils import is_xpu_available + + if is_xpu_available(): + device = "xpu" + + if config.get("torch_config", None) is None: + backend = None + if device == "cpu" or device == "xpu" or device == "gpu": + backend = "ccl" + elif device == "hpu": + backend = "hccl" + torch_config = common.TorchConfig(backend=backend, device=device) + else: + customer_torch_config = config.get("torch_config") + torch_config = common.TorchConfig(**customer_torch_config, device=device) + + if config.get("failure_config", None) is None: + failure_config = FailureConfig() + else: + customer_failure_config = config.get("failure_config") + failure_config = FailureConfig(**customer_failure_config) + + if config.get("run_config", None) is None: + run_config = RunConfig(failure_config=failure_config) + else: + customer_run_config = config.get("run_config") + if customer_run_config.get("failure_config", None) is None: + customer_run_config["failure_config"] = failure_config + run_config = RunConfig(**customer_run_config) + + trainer = TorchTrainer( + train_func, + train_loop_config=config, + scaling_config=scaling_config, + torch_config=torch_config, + run_config=run_config, + ) + results = trainer.fit() + if external_config is not None: + return results + + +if __name__ == "__main__": + main() diff --git a/comps/finetuning/llm_on_ray/finetune/finetune_config.py b/comps/finetuning/llm_on_ray/finetune/finetune_config.py new file mode 100644 index 000000000..c9aada04b --- /dev/null +++ b/comps/finetuning/llm_on_ray/finetune/finetune_config.py @@ -0,0 +1,167 @@ +# +# Copyright 2023 The LLM-on-Ray Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from pydantic import BaseModel, validator +from typing import Optional, List, Union + + +PRECISION_BF16 = "bf16" +PRECISION_FP16 = "fp16" +PRECISION_NO = "no" + +DEVICE_CPU = "cpu" +DEVICE_HPU = "hpu" +DEVICE_GPU = "gpu" + +ACCELERATE_STRATEGY_DDP = "DDP" +ACCELERATE_STRATEGY_FSDP = "FSDP" +ACCELERATE_STRATEGY_DEEPSPEED = "DEEPSPEED" + + +class GeneralConfig(BaseModel): + trust_remote_code: bool + use_auth_token: Optional[str] + + +class LoraConfig(BaseModel): + task_type: str + r: int + lora_alpha: int + lora_dropout: float + target_modules: Optional[List[str]] = None + + +class General(BaseModel): + base_model: str + tokenizer_name: Optional[str] = None + gaudi_config_name: Optional[str] = None + gpt_base_model: bool + output_dir: str + report_to: str = "none" + resume_from_checkpoint: Optional[str] = None + save_strategy: str = "no" + config: GeneralConfig + lora_config: Optional[LoraConfig] = None + enable_gradient_checkpointing: bool = False + + @validator("report_to") + def check_report_to(cls, v: str): + assert v in ["none", "tensorboard"] + return v + + +class Dataset(BaseModel): + train_file: str + validation_file: Optional[str] + validation_split_percentage: int + max_length: int = 512 + group: bool = True + block_size: int = 512 + shuffle: bool = False + max_source_length: int = 384 + padding_side: str = "right" + truncation_side: str = "right" + max_seq_length: int = 512 + truncation: bool = True + padding: bool = True + mask_input: bool = True + mask_response: bool = True + data_preprocess_type: str = "neural_chat" + max_train_samples: Optional[int] + max_eval_samples: Optional[int] + + +class RayResourceConfig(BaseModel): + CPU: int + GPU: int = 0 + HPU: int = 0 + + +class Training(BaseModel): + optimizer: str + batch_size: int + epochs: int + max_train_steps: Optional[int] = None + learning_rate: float + lr_scheduler: str + weight_decay: float + device: str = DEVICE_CPU + hpu_execution_mode: str = "lazy" + num_training_workers: int + resources_per_worker: RayResourceConfig + accelerate_mode: str = ACCELERATE_STRATEGY_DDP + mixed_precision: str = PRECISION_NO + gradient_accumulation_steps: int = 1 + logging_steps: int = 10 + deepspeed_config_file: str = "" + + @validator("device") + def check_device(cls, v: str): + # will convert to lower case + if v: + assert v.lower() in [DEVICE_CPU, DEVICE_GPU, DEVICE_HPU] + return v.lower() + + @validator("hpu_execution_mode") + def check_hpu_execution_mode(cls, v: str): + if v: + assert v in ["lazy", "eager", "eager.compile"] + return v + + @validator("accelerate_mode") + def check_accelerate_mode(cls, v: str): + if v: + assert v in [ + ACCELERATE_STRATEGY_DDP, + ACCELERATE_STRATEGY_FSDP, + ACCELERATE_STRATEGY_DEEPSPEED, + ] + return v + + @validator("mixed_precision") + def check_mixed_precision(cls, v: str): + if v: + assert v in [PRECISION_BF16, PRECISION_FP16, PRECISION_NO] + return v + + @validator("logging_steps") + def check_logging_steps(cls, v: int): + assert v > 0 + return v + + # @model_validator(mode='after') + # def check_device_and_accelerate_mode(self) -> "Training": + # dev = self.device + # res = self.resources_per_worker + # mode = self.accelerate_mode + # if dev == "CPU": + # if res.GPU is not None and res.GPU > 0: + # raise ValueError("Please not specified GPU resource when use CPU only in Ray.") + # if mode != "CPU_DDP": + # raise ValueError("Please specified CPU related accelerate mode when use CPU only in Ray.") + # elif dev == "GPU": + # if res.GPU is None or res.GPU == 0: + # raise ValueError("Please 
specified GPU resource when use GPU to fine tune in Ray.") + # if mode not in ["GPU_DDP", "GPU_FSDP"]: + # raise ValueError("Please speicifed GPU related accelerate mode when use GPU to fine tune in Ray.") + + # return self + + +class FinetuneConfig(BaseModel): + General: General + Dataset: Dataset + Training: Training diff --git a/comps/finetuning/models.py b/comps/finetuning/models.py new file mode 100644 index 000000000..f6757364d --- /dev/null +++ b/comps/finetuning/models.py @@ -0,0 +1,53 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from datetime import datetime +from typing import List, Optional + +from pydantic import BaseModel + + +class FineTuningJobsRequest(BaseModel): + training_file: str + model: str + + +class Hyperparameters(BaseModel): + n_epochs: int + batch_size: int + learning_rate_multiplier: float + + +class FineTuningJob(BaseModel): + object: str = "fine_tuning.job" # Set as constant + id: str + model: str + created_at: int + finished_at: int = None + fine_tuned_model: str = None + organization_id: str = None + result_files: List[str] = None + status: str + validation_file: str = None + training_file: str + hyperparameters: Hyperparameters + trained_tokens: int = None + integrations: List[str] = [] # Empty list by default + seed: int + estimated_finish: int = 0 # Set default value to 0 + + +class FineTuningJobList(BaseModel): + object: str = "list" # Set as constant + data: List[FineTuningJob] + has_more: bool + + +class FineTuningJobEvent(BaseModel): + object: str = "fine_tuning.job.event" # Set as constant + id: str + created_at: int + level: str + message: str + data: None = None # No data expected for this event type, set to None + type: str = "message" # Default event type is "message" diff --git a/comps/finetuning/models/llama-2-7b-chat-hf.yaml b/comps/finetuning/models/llama-2-7b-chat-hf.yaml new file mode 100644 index 000000000..3918196a2 --- /dev/null +++ b/comps/finetuning/models/llama-2-7b-chat-hf.yaml @@ -0,0 +1,39 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +General: + base_model: meta-llama/Llama-2-7b-chat-hf + gpt_base_model: false + output_dir: /tmp/llm-ray/output + save_strategy: no + config: + trust_remote_code: false + use_auth_token: null + lora_config: + task_type: CAUSAL_LM + r: 8 + lora_alpha: 32 + lora_dropout: 0.1 + target_modules: + - q_proj + - v_proj + enable_gradient_checkpointing: false +Dataset: + train_file: examples/data/sample_finetune_data_small.jsonl + group: false + validation_file: null + validation_split_percentage: 5 +Training: + optimizer: adamw_torch + batch_size: 2 + epochs: 3 + learning_rate: 1.0e-05 + lr_scheduler: linear + weight_decay: 0.0 + mixed_precision: bf16 + device: cpu + num_training_workers: 1 + resources_per_worker: + CPU: 32 + gradient_accumulation_steps: 1 + logging_steps: 10 diff --git a/comps/finetuning/models/mistral-7b-v0.1.yaml b/comps/finetuning/models/mistral-7b-v0.1.yaml new file mode 100644 index 000000000..29d05de93 --- /dev/null +++ b/comps/finetuning/models/mistral-7b-v0.1.yaml @@ -0,0 +1,45 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +General: + base_model: mistralai/Mistral-7B-v0.1 + gpt_base_model: false + output_dir: /tmp/llm-ray/output + save_strategy: no + config: + trust_remote_code: false + use_auth_token: null + lora_config: + task_type: CAUSAL_LM + r: 8 + lora_alpha: 32 + lora_dropout: 0.1 + target_modules: + - q_proj + - k_proj + - v_proj + - o_proj + - gate_proj + - 
up_proj + - down_proj + - lm_head + enable_gradient_checkpointing: false +Dataset: + train_file: examples/data/sample_finetune_data_small.jsonl + validation_file: null + validation_split_percentage: 5 +Training: + optimizer: adamw_torch + batch_size: 2 + epochs: 3 + learning_rate: 1.0e-05 + lr_scheduler: linear + weight_decay: 0.0 + mixed_precision: bf16 + device: cpu + num_training_workers: 2 + resources_per_worker: + CPU: 32 + accelerate_mode: DDP + gradient_accumulation_steps: 1 + logging_steps: 10 diff --git a/comps/finetuning/requirements.txt b/comps/finetuning/requirements.txt new file mode 100644 index 000000000..cefb56399 --- /dev/null +++ b/comps/finetuning/requirements.txt @@ -0,0 +1,3 @@ +fastapi +pydantic +uvicorn From 9e9337b711e8bab9889d8d07035582ab82c6f0d9 Mon Sep 17 00:00:00 2001 From: Xinyu Ye Date: Fri, 16 Aug 2024 10:43:16 +0800 Subject: [PATCH 02/13] refined readme. Signed-off-by: Xinyu Ye --- comps/finetuning/README.md | 31 ++++--- comps/finetuning/finetune_runner.py | 11 --- comps/finetuning/finetuning.py | 90 ------------------- comps/finetuning/finetuning_service.py | 40 +++++++++ .../llm_on_ray/finetune/finetune_config.py | 4 +- 5 files changed, 60 insertions(+), 116 deletions(-) delete mode 100644 comps/finetuning/finetuning.py create mode 100644 comps/finetuning/finetuning_service.py diff --git a/comps/finetuning/README.md b/comps/finetuning/README.md index 3b3daf9ef..62703e8c4 100644 --- a/comps/finetuning/README.md +++ b/comps/finetuning/README.md @@ -14,35 +14,40 @@ pip install -r requirements.txt ### 1.2.1 Start Ray Cluster -TBD - -### 1.2.2 Start Finetuning Service +OneCCL and Intel MPI libraries should be dynamically linked in every node before Ray starts: +```bash +source $(python -c "import oneccl_bindings_for_pytorch as torch_ccl; print(torch_ccl.cwd)")/env/setvars.sh +``` +Start Ray locally using the following command. ```bash -export RAY_ADDRESS="ray://${ray_head_ip}:10001" -python finetuning/finetuning.py +ray start --head ``` -# 🚀2. Consume Finetuning Service +For a multi-node cluster, start additional Ray worker nodes with below command. +```bash +ray start --address='${head_node_ip}:6379' +``` -## 2.1 Check Service Status +### 1.2.2 Start Finetuning Service ```bash -curl http://${your_ip}:8000/v1/health_check\ - -X GET \ - -H 'Content-Type: application/json' +export RAY_ADDRESS="ray://${ray_head_ip}:10001" +python finetuning/finetuning_service.py ``` -## 2.2 Create fine-tuning job +# 🚀2. 
Consume Finetuning Service + +## 2.1 Create fine-tuning job -Assuming a training file `file-vGxE9KywnSUkEL6dv9qZxKAF.jsonl` is uploaded, the following script launches a finetuning job using `meta-llama/Llama-2-7b-chat-hf` as base model: +Assuming a training file `alpaca_data.json` is uploaded, the following script launches a finetuning job using `meta-llama/Llama-2-7b-chat-hf` as base model: ```bash curl http://${your_ip}:8000/v1/fine_tuning/jobs \ -X POST \ -H "Content-Type: application/json" \ -d '{ - "training_file": "file-vGxE9KywnSUkEL6dv9qZxKAF.jsonl", + "training_file": "alpaca_data.json", "model": "meta-llama/Llama-2-7b-chat-hf" }' ``` diff --git a/comps/finetuning/finetune_runner.py b/comps/finetuning/finetune_runner.py index 646341d73..fec53bf04 100644 --- a/comps/finetuning/finetune_runner.py +++ b/comps/finetuning/finetune_runner.py @@ -2,16 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 import argparse -import time -import uuid -from typing import List from llm_on_ray.finetune.finetune_config import FinetuneConfig from pydantic_yaml import parse_yaml_raw_as -from ray.train.base_trainer import TrainingFailedError -from ray.tune.callback import Callback -from ray.tune.experiment import Trial -from ray.tune.logger import LoggerCallback from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments @@ -38,10 +31,6 @@ def main(): from llm_on_ray.finetune.finetune import main as llm_on_ray_finetune_main llm_on_ray_finetune_main(finetune_config) - # try: - # llm_on_ray_finetune_main(finetune_config) - # except TrainingFailedError as e: - # print(e) if __name__ == "__main__": diff --git a/comps/finetuning/finetuning.py b/comps/finetuning/finetuning.py deleted file mode 100644 index a02833c00..000000000 --- a/comps/finetuning/finetuning.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright (C) 2024 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - -import uvicorn -from fastapi import BackgroundTasks, Cookie, FastAPI, Form, Header, Response -from handlers import ( - handle_cancel_finetuning_job, - handle_create_finetuning_jobs, - handle_list_finetuning_jobs, - handle_retrieve_finetuning_job, -) -from models import FineTuningJob, FineTuningJobList, FineTuningJobsRequest -from pydantic import BaseModel - -app = FastAPI() - - -@app.post("/v1/fine_tuning/jobs", response_model=FineTuningJob) -def create_finetuning_jobs(request: FineTuningJobsRequest, background_tasks: BackgroundTasks): - return handle_create_finetuning_jobs(request, background_tasks) - # return { - # "object": "fine_tuning.job", - # "id": "ftjob-abc123", - # "model": "davinci-002", - # "created_at": 1692661014, - # "finished_at": 1692661190, - # "fine_tuned_model": "ft:davinci-002:my-org:custom_suffix:7q8mpxmy", - # "organization_id": "org-123", - # "result_files": ["file-abc123"], - # "status": "succeeded", - # "validation_file": None, - # "training_file": "file-abc123", - # "hyperparameters": { - # "n_epochs": 4, - # "batch_size": 1, - # "learning_rate_multiplier": 1.0, - # }, - # "trained_tokens": 5768, - # "integrations": [], - # "seed": 0, - # "estimated_finish": 0, - # } - - -@app.get("/v1/fine_tuning/jobs", response_model=FineTuningJobList) -def list_finetuning_jobs(): - return handle_list_finetuning_jobs() - # return { - # "object": "list", - # "data": [ - # { - # "object": "fine_tuning.job", - # "id": "ftjob-abc123", - # "model": "davinci-002", - # "created_at": 1692661014, - # "finished_at": 1692661190, - # "fine_tuned_model": "ft:davinci-002:my-org:custom_suffix:7q8mpxmy", - # 
"organization_id": "org-123", - # "result_files": ["file-abc123"], - # "status": "succeeded", - # "training_file": "file-abc123", - # "hyperparameters": { - # "n_epochs": 4, - # "batch_size": 1, - # "learning_rate_multiplier": 1.0, - # }, - # "trained_tokens": 5768, - # "integrations": [], - # "seed": 0, - # "estimated_finish": 0, - # }, - # ], - # "has_more": True, - # } - - -@app.get("/v1/fine_tuning/jobs/{fine_tuning_job_id}", response_model=FineTuningJob) -def retrieve_finetuning_job(fine_tuning_job_id): - job = handle_retrieve_finetuning_job(fine_tuning_job_id) - return job - - -@app.post("/v1/fine_tuning/jobs/{fine_tuning_job_id}/cancel", response_model=FineTuningJob) -def cancel_finetuning_job(fine_tuning_job_id): - job = handle_cancel_finetuning_job(fine_tuning_job_id) - return job - - -if __name__ == "__main__": - uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/comps/finetuning/finetuning_service.py b/comps/finetuning/finetuning_service.py new file mode 100644 index 000000000..6e15673ca --- /dev/null +++ b/comps/finetuning/finetuning_service.py @@ -0,0 +1,40 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import uvicorn +from fastapi import BackgroundTasks, FastAPI +from handlers import ( + handle_cancel_finetuning_job, + handle_create_finetuning_jobs, + handle_list_finetuning_jobs, + handle_retrieve_finetuning_job, +) +from models import FineTuningJob, FineTuningJobList, FineTuningJobsRequest + +app = FastAPI() + + +@app.post("/v1/fine_tuning/jobs", response_model=FineTuningJob) +def create_finetuning_jobs(request: FineTuningJobsRequest, background_tasks: BackgroundTasks): + return handle_create_finetuning_jobs(request, background_tasks) + + +@app.get("/v1/fine_tuning/jobs", response_model=FineTuningJobList) +def list_finetuning_jobs(): + return handle_list_finetuning_jobs() + + +@app.get("/v1/fine_tuning/jobs/{fine_tuning_job_id}", response_model=FineTuningJob) +def retrieve_finetuning_job(fine_tuning_job_id): + job = handle_retrieve_finetuning_job(fine_tuning_job_id) + return job + + +@app.post("/v1/fine_tuning/jobs/{fine_tuning_job_id}/cancel", response_model=FineTuningJob) +def cancel_finetuning_job(fine_tuning_job_id): + job = handle_cancel_finetuning_job(fine_tuning_job_id) + return job + + +if __name__ == "__main__": + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/comps/finetuning/llm_on_ray/finetune/finetune_config.py b/comps/finetuning/llm_on_ray/finetune/finetune_config.py index c9aada04b..c046b86d3 100644 --- a/comps/finetuning/llm_on_ray/finetune/finetune_config.py +++ b/comps/finetuning/llm_on_ray/finetune/finetune_config.py @@ -80,8 +80,8 @@ class Dataset(BaseModel): mask_input: bool = True mask_response: bool = True data_preprocess_type: str = "neural_chat" - max_train_samples: Optional[int] - max_eval_samples: Optional[int] + max_train_samples: int = 0 + max_eval_samples: int = 0 class RayResourceConfig(BaseModel): From 86c1b4bd4caf5b6024b07e26a47cc03044f5f395 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 16 Aug 2024 02:47:56 +0000 Subject: [PATCH 03/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- comps/finetuning/README.md | 3 + .../llm_on_ray/common/torch_config.py | 21 +++---- .../llm_on_ray/finetune/data_process.py | 27 +++----- .../llm_on_ray/finetune/finetune.py | 61 ++++++------------- .../llm_on_ray/finetune/finetune_config.py | 4 +- 5 files changed, 40 insertions(+), 
76 deletions(-) diff --git a/comps/finetuning/README.md b/comps/finetuning/README.md index 62703e8c4..03d04787c 100644 --- a/comps/finetuning/README.md +++ b/comps/finetuning/README.md @@ -15,16 +15,19 @@ pip install -r requirements.txt ### 1.2.1 Start Ray Cluster OneCCL and Intel MPI libraries should be dynamically linked in every node before Ray starts: + ```bash source $(python -c "import oneccl_bindings_for_pytorch as torch_ccl; print(torch_ccl.cwd)")/env/setvars.sh ``` Start Ray locally using the following command. + ```bash ray start --head ``` For a multi-node cluster, start additional Ray worker nodes with below command. + ```bash ray start --address='${head_node_ip}:6379' ``` diff --git a/comps/finetuning/llm_on_ray/common/torch_config.py b/comps/finetuning/llm_on_ray/common/torch_config.py index 115093ba3..522bf58ad 100644 --- a/comps/finetuning/llm_on_ray/common/torch_config.py +++ b/comps/finetuning/llm_on_ray/common/torch_config.py @@ -14,13 +14,14 @@ # limitations under the License. # -from ray.train.torch.config import _TorchBackend -from ray.train.torch.config import TorchConfig as RayTorchConfig -from ray.train._internal.worker_group import WorkerGroup -from dataclasses import dataclass -from typing import Optional import os import sys +from dataclasses import dataclass +from typing import Optional + +from ray.train._internal.worker_group import WorkerGroup +from ray.train.torch.config import TorchConfig as RayTorchConfig +from ray.train.torch.config import _TorchBackend # The package importlib_metadata is in a different place, depending on the Python version. if sys.version_info < (3, 8): @@ -40,7 +41,7 @@ def backend_cls(self): def xpu_libs_import(): - """try to import IPEX and oneCCL.""" + """Try to import IPEX and oneCCL.""" try: import intel_extension_for_pytorch except ImportError: @@ -56,7 +57,7 @@ def xpu_libs_import(): def hpu_libs_import(): - """try to import habana frameworkfs for torch""" + """Try to import habana frameworkfs for torch.""" try: import habana_frameworks.torch # noqa: F401 except ImportError as habana_not_exist: @@ -72,11 +73,7 @@ class EnableCCLBackend(_TorchBackend): device: Optional[str] = None def on_start(self, worker_group: WorkerGroup, backend_config: RayTorchConfig): - libs_import = ( - hpu_libs_import - if self.device is not None and self.device.startswith("hpu") - else xpu_libs_import - ) + libs_import = hpu_libs_import if self.device is not None and self.device.startswith("hpu") else xpu_libs_import for i in range(len(worker_group)): worker_group.execute_single_async(i, libs_import) super().on_start(worker_group, backend_config) diff --git a/comps/finetuning/llm_on_ray/finetune/data_process.py b/comps/finetuning/llm_on_ray/finetune/data_process.py index c617a4215..66d90bada 100644 --- a/comps/finetuning/llm_on_ray/finetune/data_process.py +++ b/comps/finetuning/llm_on_ray/finetune/data_process.py @@ -28,7 +28,9 @@ class DataProcessor: def __init__(self, config, tokenizer): self.tokenizer = tokenizer self.end = tokenizer.eos_token - self.intro = "Below is an instruction that describes a task. Write a response that appropriately completes the request." + self.intro = ( + "Below is an instruction that describes a task. Write a response that appropriately completes the request." 
+ ) self.instruction = "### Instruction:\n" self.input = "### Input:\n" self.response = "### Response:\n" @@ -68,16 +70,7 @@ def make_prompt(self, examples): ) prompts["prompt_sources"].append(prompt) else: - prompt = ( - self.intro - + self.end - + "\n" - + self.instruction - + instruction - + self.end - + "\n" - + self.response - ) + prompt = self.intro + self.end + "\n" + self.instruction + instruction + self.end + "\n" + self.response prompts["prompt_sources"].append(prompt) prompt_response = response + self.end prompts["prompt_targets"].append(prompt_response) @@ -119,9 +112,7 @@ def tokenize_by_neural_chat(self, examples): instruction, re.DOTALL, ) - convs_tokens = [ - self.tokenizer.tokenize(conv) + self.tokenizer.tokenize("\n") for conv in convs - ] + convs_tokens = [self.tokenizer.tokenize(conv) + self.tokenizer.tokenize("\n") for conv in convs] header_tokens = self.tokenizer.tokenize(header) + self.tokenizer.tokenize("\n") max_input = self.max_source_length - len(header_tokens) - len(assistant_tokens) truncated_convs = self.__truncate_sequences(convs_tokens, max_input) @@ -129,14 +120,10 @@ def tokenize_by_neural_chat(self, examples): truncated_convs = [convs_tokens[-1][: max_input - 3] + convs_tokens[-1][-3:]] prompt_tokens = [header_tokens] + truncated_convs + [assistant_tokens] - prompt_ids = [ - self.tokenizer.convert_tokens_to_ids(prompt_token) for prompt_token in prompt_tokens - ] + prompt_ids = [self.tokenizer.convert_tokens_to_ids(prompt_token) for prompt_token in prompt_tokens] prompt_ids = list(chain(*prompt_ids)) - resp_ids = self.tokenizer.convert_tokens_to_ids( - self.tokenizer.tokenize(response.strip()) - ) + resp_ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(response.strip())) # keep last and eos_id max_resp = self.max_seq_length - len(prompt_ids) - 1 diff --git a/comps/finetuning/llm_on_ray/finetune/finetune.py b/comps/finetuning/llm_on_ray/finetune/finetune.py index 46714cdbb..c8a86bb5f 100644 --- a/comps/finetuning/llm_on_ray/finetune/finetune.py +++ b/comps/finetuning/llm_on_ray/finetune/finetune.py @@ -16,32 +16,26 @@ #!/usr/bin/env python -import os import argparse +import copy +import os import re import sys -import copy - -from typing import Any, Dict, Union, Optional - from itertools import chain +from typing import Any, Dict, Optional, Union -import torch import datasets -import transformers - -from peft import get_peft_model, LoraConfig - import ray -from ray.train.torch import TorchTrainer -from ray.air.config import ScalingConfig -from ray.air import RunConfig, FailureConfig - -from pydantic_yaml import parse_yaml_raw_as - +import torch +import transformers from llm_on_ray import common from llm_on_ray.finetune.data_process import DataProcessor from llm_on_ray.finetune.finetune_config import FinetuneConfig +from peft import LoraConfig, get_peft_model +from pydantic_yaml import parse_yaml_raw_as +from ray.air import FailureConfig, RunConfig +from ray.air.config import ScalingConfig +from ray.train.torch import TorchTrainer def adapt_transformers_to_device(config: Dict): @@ -176,9 +170,7 @@ def local_load(name, **load_config): validation_split_percentage = config["Dataset"].get("validation_split_percentage", 0) if validation_split_percentage / 100 > 0.0 and validation_split_percentage / 100 < 1.0: - dataset_dict = train_dataset.train_test_split( - test_size=validation_split_percentage / 100 - ) + dataset_dict = train_dataset.train_test_split(test_size=validation_split_percentage / 100) dataset_dict["validation"] = 
dataset_dict["test"] return dataset_dict @@ -193,9 +185,7 @@ def local_load(name, **load_config): if "validation" not in raw_dataset.keys() and ( validation_split_percentage / 100 > 0.0 and validation_split_percentage / 100 < 1.0 ): - dataset_dict = raw_dataset["train"].train_test_split( - test_size=validation_split_percentage / 100 - ) + dataset_dict = raw_dataset["train"].train_test_split(test_size=validation_split_percentage / 100) dataset_dict["validation"] = dataset_dict["test"] return dataset_dict @@ -265,9 +255,7 @@ def load_model(config: Dict): model_name = config["General"]["base_model"] model_dtype = convert_dtype(config["Training"].get("mixed_precision", "no")) model_config = config["General"].get("config", {}) - model = transformers.AutoModelForCausalLM.from_pretrained( - model_name, torch_dtype=model_dtype, **model_config - ) + model = transformers.AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=model_dtype, **model_config) lora_config = config["General"].get("lora_config", None) if lora_config: @@ -295,17 +283,14 @@ def get_trainer(config: Dict, model, tokenizer, tokenized_dataset, data_collator model=model, args=training_args, train_dataset=tokenized_dataset["train"], - eval_dataset=tokenized_dataset["validation"] - if tokenized_dataset.get("validation") is not None - else None, + eval_dataset=tokenized_dataset["validation"] if tokenized_dataset.get("validation") is not None else None, tokenizer=tokenizer, data_collator=data_collator, ) return training_args, trainer elif device in ["hpu"]: - from optimum.habana.transformers import GaudiTrainer - from optimum.habana.transformers import GaudiTrainingArguments from optimum.habana import GaudiConfig + from optimum.habana.transformers import GaudiTrainer, GaudiTrainingArguments # If gaudi_config_name is provided, load gaudi_config from huggingface model hub(https://huggingface.co/Habana), otherwise use default gaudi_config gaudi_config_name = config["General"].get("gaudi_config_name", None) @@ -322,9 +307,7 @@ def get_trainer(config: Dict, model, tokenizer, tokenized_dataset, data_collator args=training_args, gaudi_config=gaudi_config, train_dataset=tokenized_dataset["train"], - eval_dataset=tokenized_dataset["validation"] - if tokenized_dataset.get("validation") is not None - else None, + eval_dataset=tokenized_dataset["validation"] if tokenized_dataset.get("validation") is not None else None, tokenizer=tokenizer, data_collator=data_collator, ) @@ -366,9 +349,7 @@ def train_func(config: Dict[str, Any]): def get_finetune_config(): - parser = argparse.ArgumentParser( - description="Finetune a transformers model on a causal language modeling task" - ) + parser = argparse.ArgumentParser(description="Finetune a transformers model on a causal language modeling task") parser.add_argument( "--config_file", type=str, @@ -402,9 +383,7 @@ def main(external_config=None): resources_per_worker = config["Training"].get("resources_per_worker") if num_training_workers > 1 and config["Training"].get("accelerate_mode", None) is None: - config["Training"][ - "accelerate_mode" - ] = "DDP" # will use DDP to accelerate if no method specified + config["Training"]["accelerate_mode"] = "DDP" # will use DDP to accelerate if no method specified ccl_worker_count = 1 device = config["Training"]["device"] @@ -427,9 +406,7 @@ def main(external_config=None): runtime_env["pip"] = ["transformers==4.26.0"] if device == "gpu": - num_cpus = ( - resources_per_worker["CPU"] * num_training_workers + 1 - ) # additional 1 for head worker + num_cpus = 
resources_per_worker["CPU"] * num_training_workers + 1 # additional 1 for head worker ray.init(num_cpus=num_cpus, runtime_env=runtime_env) else: ray.init(runtime_env=runtime_env) diff --git a/comps/finetuning/llm_on_ray/finetune/finetune_config.py b/comps/finetuning/llm_on_ray/finetune/finetune_config.py index c046b86d3..ba2a7671e 100644 --- a/comps/finetuning/llm_on_ray/finetune/finetune_config.py +++ b/comps/finetuning/llm_on_ray/finetune/finetune_config.py @@ -14,9 +14,9 @@ # limitations under the License. # -from pydantic import BaseModel, validator -from typing import Optional, List, Union +from typing import List, Optional, Union +from pydantic import BaseModel, validator PRECISION_BF16 = "bf16" PRECISION_FP16 = "fp16" From 3b75a13c65b462adc9f2e242d25036e4e616e6ab Mon Sep 17 00:00:00 2001 From: Xinyu Ye Date: Fri, 16 Aug 2024 16:33:13 +0800 Subject: [PATCH 04/13] refined code structure. Signed-off-by: Xinyu Ye --- comps/finetuning/README.md | 7 +++-- comps/finetuning/docker/Dockerfile.finetune | 22 ------------- comps/finetuning/docker/Dockerfile_cpu | 31 +++++++++++++++++++ comps/finetuning/finetune_runner.py | 4 +-- comps/finetuning/finetuning_service.py | 2 +- comps/finetuning/handlers.py | 6 ++-- .../finetuning/llm_on_ray/common/__init__.py | 4 +-- comps/finetuning/llm_on_ray/common/common.py | 2 +- .../llm_on_ray/finetune/finetune.py | 6 ++-- comps/finetuning/requirements.txt | 7 ++++- 10 files changed, 54 insertions(+), 37 deletions(-) delete mode 100644 comps/finetuning/docker/Dockerfile.finetune create mode 100644 comps/finetuning/docker/Dockerfile_cpu diff --git a/comps/finetuning/README.md b/comps/finetuning/README.md index 03d04787c..dae128bcf 100644 --- a/comps/finetuning/README.md +++ b/comps/finetuning/README.md @@ -7,6 +7,9 @@ LLM Fine-tuning microservice involves adapting a base model to a specific task o ## 1.1 Install Requirements ```bash +python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu +python -m pip install intel-extension-for-pytorch +python -m pip install oneccl_bind_pt --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/cpu/us/ pip install -r requirements.txt ``` @@ -36,14 +39,14 @@ ray start --address='${head_node_ip}:6379' ```bash export RAY_ADDRESS="ray://${ray_head_ip}:10001" -python finetuning/finetuning_service.py +python finetuning_service.py ``` # 🚀2. 
Consume Finetuning Service ## 2.1 Create fine-tuning job -Assuming a training file `alpaca_data.json` is uploaded, the following script launches a finetuning job using `meta-llama/Llama-2-7b-chat-hf` as base model: +Assuming a training file `alpaca_data.json` is uploaded, it can be downloaded in [here](https://github.com/tatsu-lab/stanford_alpaca/blob/main/alpaca_data.json), the following script launches a finetuning job using `meta-llama/Llama-2-7b-chat-hf` as base model: ```bash curl http://${your_ip}:8000/v1/fine_tuning/jobs \ diff --git a/comps/finetuning/docker/Dockerfile.finetune b/comps/finetuning/docker/Dockerfile.finetune deleted file mode 100644 index 30b9c6171..000000000 --- a/comps/finetuning/docker/Dockerfile.finetune +++ /dev/null @@ -1,22 +0,0 @@ -# Use the same python version with ray -FROM python:3.10.14 - -WORKDIR /root/opea-finetune - -RUN --mount=type=cache,target=/var/cache/apt apt-get update -y \ - && apt-get install -y vim htop net-tools dnsutils \ - && apt-get clean \ - && rm -rf /var/lib/apt/lists/* - -# COPY ./install-llm-on-ray.sh /tmp/install-llm-on-ray.sh -# RUN --mount=type=cache,target=/root/.cache/pip /tmp/install-llm-on-ray.sh - -COPY ./ . - -RUN --mount=type=cache,target=/root/.cache/pip cd ./llm-on-ray && pip install -v -e .[cpu] --extra-index-url https://download.pytorch.org/whl/cpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/cpu/us/ - -RUN --mount=type=cache,target=/root/.cache/pip pip install --no-cache-dir --upgrade -r requirements.txt - -RUN echo 'source $(python -c "import oneccl_bindings_for_pytorch as torch_ccl; print(torch_ccl.cwd)")/env/setvars.sh' >> ~/.bashrc - -CMD ["bash", "-c", "./run.sh"] \ No newline at end of file diff --git a/comps/finetuning/docker/Dockerfile_cpu b/comps/finetuning/docker/Dockerfile_cpu new file mode 100644 index 000000000..e99c50d4d --- /dev/null +++ b/comps/finetuning/docker/Dockerfile_cpu @@ -0,0 +1,31 @@ +# Use the same python version with ray +FROM python:3.10.14 + +RUN --mount=type=cache,target=/var/cache/apt apt-get update -y \ + && apt-get install -y vim htop net-tools dnsutils \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* + +RUN useradd -m -s /bin/bash user && \ +mkdir -p /home/user && \ +chown -R user /home/user/ + +USER user + +COPY comps /home/user/comps + +RUN pip install --no-cache-dir --upgrade pip && \ + python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu && \ + python -m pip install intel-extension-for-pytorch && \ + python -m pip install oneccl_bind_pt --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/cpu/us/ && \ + pip install --no-cache-dir -r /home/user/comps/finetuning/requirements.txt + +RUN echo 'source $(python -c "import oneccl_bindings_for_pytorch as torch_ccl; print(torch_ccl.cwd)")/env/setvars.sh' >> ~/.bashrc + +RUN ray start --head + +ENV PYTHONPATH=$PYTHONPATH:/home/user + +WORKDIR /home/user/comps/embeddings/langchain + +CMD ["python", "finetuning_service.py"] \ No newline at end of file diff --git a/comps/finetuning/finetune_runner.py b/comps/finetuning/finetune_runner.py index fec53bf04..ff9a8d8f8 100644 --- a/comps/finetuning/finetune_runner.py +++ b/comps/finetuning/finetune_runner.py @@ -3,7 +3,7 @@ import argparse -from llm_on_ray.finetune.finetune_config import FinetuneConfig +from .llm_on_ray.finetune.finetune_config import FinetuneConfig from pydantic_yaml import parse_yaml_raw_as from transformers import TrainerCallback, TrainerControl, TrainerState, 
TrainingArguments @@ -28,7 +28,7 @@ def main(): callback = FineTuneCallback() finetune_config["Training"]["callbacks"] = [callback] - from llm_on_ray.finetune.finetune import main as llm_on_ray_finetune_main + from .llm_on_ray.finetune.finetune import main as llm_on_ray_finetune_main llm_on_ray_finetune_main(finetune_config) diff --git a/comps/finetuning/finetuning_service.py b/comps/finetuning/finetuning_service.py index 6e15673ca..6dddfea9e 100644 --- a/comps/finetuning/finetuning_service.py +++ b/comps/finetuning/finetuning_service.py @@ -9,7 +9,7 @@ handle_list_finetuning_jobs, handle_retrieve_finetuning_job, ) -from models import FineTuningJob, FineTuningJobList, FineTuningJobsRequest +from .models import FineTuningJob, FineTuningJobList, FineTuningJobsRequest app = FastAPI() diff --git a/comps/finetuning/handlers.py b/comps/finetuning/handlers.py index a874369f9..e00698269 100644 --- a/comps/finetuning/handlers.py +++ b/comps/finetuning/handlers.py @@ -8,9 +8,9 @@ from typing import Any, Dict, List, Set from fastapi import BackgroundTasks, HTTPException -from llm_on_ray.finetune.finetune import main -from llm_on_ray.finetune.finetune_config import FinetuneConfig -from models import FineTuningJob, FineTuningJobEvent, FineTuningJobList, FineTuningJobsRequest +from .llm_on_ray.finetune.finetune import main +from .llm_on_ray.finetune.finetune_config import FinetuneConfig +from .models import FineTuningJob, FineTuningJobEvent, FineTuningJobList, FineTuningJobsRequest from pydantic_yaml import parse_yaml_raw_as, to_yaml_file from ray.job_submission import JobSubmissionClient from ray.train.base_trainer import TrainingFailedError diff --git a/comps/finetuning/llm_on_ray/common/__init__.py b/comps/finetuning/llm_on_ray/common/__init__.py index a84df4482..97380de45 100644 --- a/comps/finetuning/llm_on_ray/common/__init__.py +++ b/comps/finetuning/llm_on_ray/common/__init__.py @@ -14,5 +14,5 @@ # limitations under the License. 
# -from llm_on_ray.common.logging import logger -from llm_on_ray.common.torch_config import TorchConfig +from .logging import logger +from .torch_config import TorchConfig diff --git a/comps/finetuning/llm_on_ray/common/common.py b/comps/finetuning/llm_on_ray/common/common.py index 87f74096c..ccc9e2565 100644 --- a/comps/finetuning/llm_on_ray/common/common.py +++ b/comps/finetuning/llm_on_ray/common/common.py @@ -18,7 +18,7 @@ import importlib import os -from llm_on_ray.common.logging import logger +from .logging import logger def import_all_modules(basedir, prefix=None): diff --git a/comps/finetuning/llm_on_ray/finetune/finetune.py b/comps/finetuning/llm_on_ray/finetune/finetune.py index c8a86bb5f..bfb27974c 100644 --- a/comps/finetuning/llm_on_ray/finetune/finetune.py +++ b/comps/finetuning/llm_on_ray/finetune/finetune.py @@ -28,9 +28,9 @@ import ray import torch import transformers -from llm_on_ray import common -from llm_on_ray.finetune.data_process import DataProcessor -from llm_on_ray.finetune.finetune_config import FinetuneConfig +from ...llm_on_ray import common +from .data_process import DataProcessor +from .finetune_config import FinetuneConfig from peft import LoraConfig, get_peft_model from pydantic_yaml import parse_yaml_raw_as from ray.air import FailureConfig, RunConfig diff --git a/comps/finetuning/requirements.txt b/comps/finetuning/requirements.txt index cefb56399..8aa19af33 100644 --- a/comps/finetuning/requirements.txt +++ b/comps/finetuning/requirements.txt @@ -1,3 +1,8 @@ +datasets fastapi +peft pydantic -uvicorn +pydantic_yaml +ray[all] +transformers +uvicorn \ No newline at end of file From 11ee6fb4a7e3edd45619034eb2d95b3af9d0d5e9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 16 Aug 2024 08:33:44 +0000 Subject: [PATCH 05/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- comps/finetuning/finetune_runner.py | 3 ++- comps/finetuning/finetuning_service.py | 1 + comps/finetuning/handlers.py | 7 ++++--- comps/finetuning/llm_on_ray/finetune/finetune.py | 7 ++++--- comps/finetuning/requirements.txt | 2 +- 5 files changed, 12 insertions(+), 8 deletions(-) diff --git a/comps/finetuning/finetune_runner.py b/comps/finetuning/finetune_runner.py index ff9a8d8f8..2fd00cb57 100644 --- a/comps/finetuning/finetune_runner.py +++ b/comps/finetuning/finetune_runner.py @@ -3,10 +3,11 @@ import argparse -from .llm_on_ray.finetune.finetune_config import FinetuneConfig from pydantic_yaml import parse_yaml_raw_as from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments +from .llm_on_ray.finetune.finetune_config import FinetuneConfig + class FineTuneCallback(TrainerCallback): def __init__(self) -> None: diff --git a/comps/finetuning/finetuning_service.py b/comps/finetuning/finetuning_service.py index 6dddfea9e..c27ab724c 100644 --- a/comps/finetuning/finetuning_service.py +++ b/comps/finetuning/finetuning_service.py @@ -9,6 +9,7 @@ handle_list_finetuning_jobs, handle_retrieve_finetuning_job, ) + from .models import FineTuningJob, FineTuningJobList, FineTuningJobsRequest app = FastAPI() diff --git a/comps/finetuning/handlers.py b/comps/finetuning/handlers.py index e00698269..a7e6951ca 100644 --- a/comps/finetuning/handlers.py +++ b/comps/finetuning/handlers.py @@ -8,14 +8,15 @@ from typing import Any, Dict, List, Set from fastapi import BackgroundTasks, HTTPException -from .llm_on_ray.finetune.finetune import main -from 
.llm_on_ray.finetune.finetune_config import FinetuneConfig -from .models import FineTuningJob, FineTuningJobEvent, FineTuningJobList, FineTuningJobsRequest from pydantic_yaml import parse_yaml_raw_as, to_yaml_file from ray.job_submission import JobSubmissionClient from ray.train.base_trainer import TrainingFailedError from ray.tune.logger import LoggerCallback +from .llm_on_ray.finetune.finetune import main +from .llm_on_ray.finetune.finetune_config import FinetuneConfig +from .models import FineTuningJob, FineTuningJobEvent, FineTuningJobList, FineTuningJobsRequest + MODEL_CONFIG_FILE_MAP = { "meta-llama/Llama-2-7b-chat-hf": "./models/llama-2-7b-chat-hf.yaml", "mistralai/Mistral-7B-v0.1": "./models/mistral-7b-v0.1.yaml", diff --git a/comps/finetuning/llm_on_ray/finetune/finetune.py b/comps/finetuning/llm_on_ray/finetune/finetune.py index bfb27974c..263226f0f 100644 --- a/comps/finetuning/llm_on_ray/finetune/finetune.py +++ b/comps/finetuning/llm_on_ray/finetune/finetune.py @@ -28,15 +28,16 @@ import ray import torch import transformers -from ...llm_on_ray import common -from .data_process import DataProcessor -from .finetune_config import FinetuneConfig from peft import LoraConfig, get_peft_model from pydantic_yaml import parse_yaml_raw_as from ray.air import FailureConfig, RunConfig from ray.air.config import ScalingConfig from ray.train.torch import TorchTrainer +from ...llm_on_ray import common +from .data_process import DataProcessor +from .finetune_config import FinetuneConfig + def adapt_transformers_to_device(config: Dict): device = config["Training"]["device"] diff --git a/comps/finetuning/requirements.txt b/comps/finetuning/requirements.txt index 8aa19af33..188cac773 100644 --- a/comps/finetuning/requirements.txt +++ b/comps/finetuning/requirements.txt @@ -5,4 +5,4 @@ pydantic pydantic_yaml ray[all] transformers -uvicorn \ No newline at end of file +uvicorn From 21b0b021771c8b1d799caee8e27dd0012b17692a Mon Sep 17 00:00:00 2001 From: Xinyu Ye Date: Mon, 19 Aug 2024 14:18:50 +0800 Subject: [PATCH 06/13] changed import relations. 
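With the package-style imports in place, the finetuning modules resolve from the repository root rather than from the current working directory (the Dockerfile arranges this via `ENV PYTHONPATH=$PYTHONPATH:/home/user`). The sketch below is illustrative only and is not part of this patch; the checkout path is an assumption:

```python
# Illustrative only: load a bundled model config through the absolute import path.
import sys

sys.path.append("/path/to/repo/root")  # assumed checkout location; adjust as needed

from pydantic_yaml import parse_yaml_raw_as

from comps.finetuning.llm_on_ray.finetune.finetune_config import FinetuneConfig

# Parse one of the shipped model configs into the typed FinetuneConfig object.
with open("comps/finetuning/models/llama-2-7b-chat-hf.yaml") as f:
    finetune_config = parse_yaml_raw_as(FinetuneConfig, f)

print(finetune_config.General.base_model)  # meta-llama/Llama-2-7b-chat-hf
print(finetune_config.Training.device)     # cpu
```
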
Signed-off-by: Xinyu Ye --- comps/finetuning/finetune_runner.py | 4 ++-- comps/finetuning/finetuning_service.py | 4 ++-- comps/finetuning/handlers.py | 7 ++----- comps/finetuning/llm_on_ray/finetune/finetune.py | 6 +++--- comps/finetuning/requirements.txt | 10 ++++++++++ 5 files changed, 19 insertions(+), 12 deletions(-) diff --git a/comps/finetuning/finetune_runner.py b/comps/finetuning/finetune_runner.py index 2fd00cb57..1ddfc4642 100644 --- a/comps/finetuning/finetune_runner.py +++ b/comps/finetuning/finetune_runner.py @@ -6,7 +6,7 @@ from pydantic_yaml import parse_yaml_raw_as from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments -from .llm_on_ray.finetune.finetune_config import FinetuneConfig +from comps.finetuning.llm_on_ray.finetune.finetune_config import FinetuneConfig class FineTuneCallback(TrainerCallback): @@ -29,7 +29,7 @@ def main(): callback = FineTuneCallback() finetune_config["Training"]["callbacks"] = [callback] - from .llm_on_ray.finetune.finetune import main as llm_on_ray_finetune_main + from comps.finetuning.llm_on_ray.finetune.finetune import main as llm_on_ray_finetune_main llm_on_ray_finetune_main(finetune_config) diff --git a/comps/finetuning/finetuning_service.py b/comps/finetuning/finetuning_service.py index c27ab724c..92fa42907 100644 --- a/comps/finetuning/finetuning_service.py +++ b/comps/finetuning/finetuning_service.py @@ -3,14 +3,14 @@ import uvicorn from fastapi import BackgroundTasks, FastAPI -from handlers import ( +from comps.finetuning.handlers import ( handle_cancel_finetuning_job, handle_create_finetuning_jobs, handle_list_finetuning_jobs, handle_retrieve_finetuning_job, ) -from .models import FineTuningJob, FineTuningJobList, FineTuningJobsRequest +from comps.finetuning.models import FineTuningJob, FineTuningJobList, FineTuningJobsRequest app = FastAPI() diff --git a/comps/finetuning/handlers.py b/comps/finetuning/handlers.py index a7e6951ca..0ae5c6c36 100644 --- a/comps/finetuning/handlers.py +++ b/comps/finetuning/handlers.py @@ -10,12 +10,9 @@ from fastapi import BackgroundTasks, HTTPException from pydantic_yaml import parse_yaml_raw_as, to_yaml_file from ray.job_submission import JobSubmissionClient -from ray.train.base_trainer import TrainingFailedError -from ray.tune.logger import LoggerCallback -from .llm_on_ray.finetune.finetune import main -from .llm_on_ray.finetune.finetune_config import FinetuneConfig -from .models import FineTuningJob, FineTuningJobEvent, FineTuningJobList, FineTuningJobsRequest +from comps.finetuning.llm_on_ray.finetune.finetune_config import FinetuneConfig +from comps.finetuning.models import FineTuningJob, FineTuningJobList, FineTuningJobsRequest MODEL_CONFIG_FILE_MAP = { "meta-llama/Llama-2-7b-chat-hf": "./models/llama-2-7b-chat-hf.yaml", diff --git a/comps/finetuning/llm_on_ray/finetune/finetune.py b/comps/finetuning/llm_on_ray/finetune/finetune.py index 263226f0f..57fe2e114 100644 --- a/comps/finetuning/llm_on_ray/finetune/finetune.py +++ b/comps/finetuning/llm_on_ray/finetune/finetune.py @@ -34,9 +34,9 @@ from ray.air.config import ScalingConfig from ray.train.torch import TorchTrainer -from ...llm_on_ray import common -from .data_process import DataProcessor -from .finetune_config import FinetuneConfig +from comps.finetuning.llm_on_ray import common +from comps.finetuning.llm_on_ray.finetune.data_process import DataProcessor +from comps.finetuning.llm_on_ray.finetune.finetune_config import FinetuneConfig def adapt_transformers_to_device(config: Dict): diff --git 
a/comps/finetuning/requirements.txt b/comps/finetuning/requirements.txt index 188cac773..c8f2f37d1 100644 --- a/comps/finetuning/requirements.txt +++ b/comps/finetuning/requirements.txt @@ -1,8 +1,18 @@ +aiohttp datasets +docarray fastapi +httpx +opentelemetry-api +opentelemetry-exporter-otlp +opentelemetry-sdk peft +prometheus-fastapi-instrumentator pydantic pydantic_yaml +pyyaml ray[all] +requests +shortuuid transformers uvicorn From f3949ce05ca46b018ceff3aee20136142f68c487 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 19 Aug 2024 06:23:48 +0000 Subject: [PATCH 07/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- comps/finetuning/finetuning_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comps/finetuning/finetuning_service.py b/comps/finetuning/finetuning_service.py index 92fa42907..a1caba88d 100644 --- a/comps/finetuning/finetuning_service.py +++ b/comps/finetuning/finetuning_service.py @@ -3,13 +3,13 @@ import uvicorn from fastapi import BackgroundTasks, FastAPI + from comps.finetuning.handlers import ( handle_cancel_finetuning_job, handle_create_finetuning_jobs, handle_list_finetuning_jobs, handle_retrieve_finetuning_job, ) - from comps.finetuning.models import FineTuningJob, FineTuningJobList, FineTuningJobsRequest app = FastAPI() From 7d5da065b1fa23131a210aaa4f9e432e491393ab Mon Sep 17 00:00:00 2001 From: Xinyu Ye Date: Tue, 20 Aug 2024 15:47:38 +0800 Subject: [PATCH 08/13] added docker for cpu part Signed-off-by: Xinyu Ye --- comps/finetuning/README.md | 27 ++++++++++++++++++--- comps/finetuning/docker/Dockerfile_cpu | 33 ++++++++++++++++---------- comps/finetuning/handlers.py | 2 ++ 3 files changed, 46 insertions(+), 16 deletions(-) diff --git a/comps/finetuning/README.md b/comps/finetuning/README.md index dae128bcf..432121d4d 100644 --- a/comps/finetuning/README.md +++ b/comps/finetuning/README.md @@ -2,7 +2,7 @@ LLM Fine-tuning microservice involves adapting a base model to a specific task or dataset to improve its performance on that task. -# 🚀1. Start Microservice with Python +# 🚀1. Start Microservice with Python (Optional 1) ## 1.1 Install Requirements @@ -38,13 +38,34 @@ ray start --address='${head_node_ip}:6379' ### 1.2.2 Start Finetuning Service ```bash +export HF_TOKEN=${your_huggingface_token} export RAY_ADDRESS="ray://${ray_head_ip}:10001" python finetuning_service.py ``` -# 🚀2. Consume Finetuning Service +# 🚀2. Start Microservice with Docker (Optional 2) -## 2.1 Create fine-tuning job +## 2.1 Build Docker Image + +Build docker image with below command: + +```bash +export HF_TOKEN=${your_huggingface_token} +cd ../../ +docker build -t opea/finetuning:latest --build-arg https_proxy=$https_proxy --build-arg http_proxy=$http_proxy --build-arg HF_TOKEN=$HF_TOKEN -f comps/finetuning/docker/Dockerfile_cpu . +``` + +## 2.2 Run Docker with CLI + +Start docker container with below command: + +```bash +docker run -d --name="finetuning-server" -p 8000:8000 --runtime=runc --ipc=host -e http_proxy=$http_proxy -e https_proxy=$https_proxy opea/finetuning:latest +``` + +# 🚀3. 
Consume Finetuning Service + +## 3.1 Create fine-tuning job Assuming a training file `alpaca_data.json` is uploaded, it can be downloaded in [here](https://github.com/tatsu-lab/stanford_alpaca/blob/main/alpaca_data.json), the following script launches a finetuning job using `meta-llama/Llama-2-7b-chat-hf` as base model: diff --git a/comps/finetuning/docker/Dockerfile_cpu b/comps/finetuning/docker/Dockerfile_cpu index e99c50d4d..1cb391af8 100644 --- a/comps/finetuning/docker/Dockerfile_cpu +++ b/comps/finetuning/docker/Dockerfile_cpu @@ -1,18 +1,23 @@ # Use the same python version with ray FROM python:3.10.14 -RUN --mount=type=cache,target=/var/cache/apt apt-get update -y \ - && apt-get install -y vim htop net-tools dnsutils \ - && apt-get clean \ - && rm -rf /var/lib/apt/lists/* +ARG HF_TOKEN + +ENV HF_TOKEN=$HF_TOKEN + +RUN apt-get update -y && apt-get install -y vim htop net-tools dnsutils RUN useradd -m -s /bin/bash user && \ -mkdir -p /home/user && \ -chown -R user /home/user/ + mkdir -p /home/user && \ + chown -R user /home/user/ + +COPY comps /home/user/comps + +RUN chown -R user /home/user/comps/finetuning USER user -COPY comps /home/user/comps +ENV PATH=$PATH:/home/user/.local/bin RUN pip install --no-cache-dir --upgrade pip && \ python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu && \ @@ -20,12 +25,14 @@ RUN pip install --no-cache-dir --upgrade pip && \ python -m pip install oneccl_bind_pt --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/cpu/us/ && \ pip install --no-cache-dir -r /home/user/comps/finetuning/requirements.txt -RUN echo 'source $(python -c "import oneccl_bindings_for_pytorch as torch_ccl; print(torch_ccl.cwd)")/env/setvars.sh' >> ~/.bashrc - -RUN ray start --head - ENV PYTHONPATH=$PYTHONPATH:/home/user -WORKDIR /home/user/comps/embeddings/langchain +WORKDIR /home/user/comps/finetuning + +RUN echo PKGPATH=$(python3 -c "import pkg_resources; print(pkg_resources.get_distribution('oneccl-bind-pt').location)") >> run.sh && \ + echo 'export LD_LIBRARY_PATH=$PKGPATH/oneccl_bindings_for_pytorch/opt/mpi/lib/:$LD_LIBRARY_PATH' >> run.sh && \ + echo 'source $PKGPATH/oneccl_bindings_for_pytorch/env/setvars.sh' >> run.sh && \ + echo ray start --head >> run.sh && \ + echo python finetuning_service.py >> run.sh -CMD ["python", "finetuning_service.py"] \ No newline at end of file +CMD bash run.sh \ No newline at end of file diff --git a/comps/finetuning/handlers.py b/comps/finetuning/handlers.py index 0ae5c6c36..eaf241890 100644 --- a/comps/finetuning/handlers.py +++ b/comps/finetuning/handlers.py @@ -61,6 +61,8 @@ def handle_create_finetuning_jobs(request: FineTuningJobsRequest, background_tas finetune_config = parse_yaml_raw_as(FinetuneConfig, f) finetune_config.Dataset.train_file = train_file_path + if os.getenv("HF_TOKEN", None): + finetune_config.General.config.use_auth_token = os.getenv("HF_TOKEN", None) job = FineTuningJob( id=f"ft-job-{uuid.uuid4()}", From 1ff81daf016a34fbd338a100dbe2fbb942071327 Mon Sep 17 00:00:00 2001 From: lkk <33276950+lkk12014402@users.noreply.github.com> Date: Wed, 21 Aug 2024 10:22:25 +0800 Subject: [PATCH 09/13] update finetuning api with openai format. (#535) * update finetuning api with openai format. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update doc and use 8001 port. 
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: test Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: root --- comps/cores/mega/micro_service.py | 38 +++-- comps/cores/proto/api_protocol.py | 222 +++++++++++++++++++++++++ comps/finetuning/README.md | 18 +- comps/finetuning/finetuning_service.py | 32 ++-- comps/finetuning/handlers.py | 49 ++++-- comps/finetuning/models.py | 53 ------ 6 files changed, 307 insertions(+), 105 deletions(-) delete mode 100644 comps/finetuning/models.py diff --git a/comps/cores/mega/micro_service.py b/comps/cores/mega/micro_service.py index e83a2836b..689fff9dd 100644 --- a/comps/cores/mega/micro_service.py +++ b/comps/cores/mega/micro_service.py @@ -3,7 +3,7 @@ import asyncio import multiprocessing -from typing import Any, Optional, Type +from typing import Any, List, Optional, Type from ..proto.docarray import TextDoc from .constants import ServiceRoleType, ServiceType @@ -154,25 +154,27 @@ def register_microservice( output_datatype: Type[Any] = TextDoc, provider: Optional[str] = None, provider_endpoint: Optional[str] = None, + methods: List[str] = ["POST"], ): def decorator(func): - micro_service = MicroService( - name=name, - service_role=service_role, - service_type=service_type, - protocol=protocol, - host=host, - port=port, - ssl_keyfile=ssl_keyfile, - ssl_certfile=ssl_certfile, - endpoint=endpoint, - input_datatype=input_datatype, - output_datatype=output_datatype, - provider=provider, - provider_endpoint=provider_endpoint, - ) - micro_service.app.router.add_api_route(endpoint, func, methods=["POST"]) - opea_microservices[name] = micro_service + if name not in opea_microservices: + micro_service = MicroService( + name=name, + service_role=service_role, + service_type=service_type, + protocol=protocol, + host=host, + port=port, + ssl_keyfile=ssl_keyfile, + ssl_certfile=ssl_certfile, + endpoint=endpoint, + input_datatype=input_datatype, + output_datatype=output_datatype, + provider=provider, + provider_endpoint=provider_endpoint, + ) + opea_microservices[name] = micro_service + opea_microservices[name].app.router.add_api_route(endpoint, func, methods=methods) return func return decorator diff --git a/comps/cores/proto/api_protocol.py b/comps/cores/proto/api_protocol.py index 957fc9d95..773bf56f2 100644 --- a/comps/cores/proto/api_protocol.py +++ b/comps/cores/proto/api_protocol.py @@ -279,3 +279,225 @@ def check_requests(request) -> Optional[JSONResponse]: ) return None + + +class Hyperparameters(BaseModel): + batch_size: Optional[Union[Literal["auto"], int]] = "auto" + """Number of examples in each batch. + + A larger batch size means that model parameters are updated less frequently, but with lower variance. + """ + + learning_rate_multiplier: Optional[Union[Literal["auto"], float]] = "auto" + """Scaling factor for the learning rate. + + A smaller learning rate may be useful to avoid overfitting. + """ + + n_epochs: Optional[Union[Literal["auto"], int]] = "auto" + """The number of epochs to train the model for. + + An epoch refers to one full cycle through the training dataset. "auto" decides + the optimal number of epochs based on the size of the dataset. If setting the + number manually, we support any number between 1 and 50 epochs. 
+ """ + + +class FineTuningJobWandbIntegration(BaseModel): + project: str + """The name of the project that the new run will be created under.""" + + entity: Optional[str] = None + """The entity to use for the run. + + This allows you to set the team or username of the WandB user that you would + like associated with the run. If not set, the default entity for the registered + WandB API key is used. + """ + + name: Optional[str] = None + """A display name to set for the run. + + If not set, we will use the Job ID as the name. + """ + + tags: Optional[List[str]] = None + """A list of tags to be attached to the newly created run. + + These tags are passed through directly to WandB. Some default tags are generated + by OpenAI: "openai/finetune", "openai/{base-model}", "openai/{ftjob-abcdef}". + """ + + +class FineTuningJobWandbIntegrationObject(BaseModel): + type: Literal["wandb"] + """The type of the integration being enabled for the fine-tuning job.""" + + wandb: FineTuningJobWandbIntegration + """The settings for your integration with Weights and Biases. + + This payload specifies the project that metrics will be sent to. Optionally, you + can set an explicit display name for your run, add tags to your run, and set a + default entity (team, username, etc) to be associated with your run. + """ + + +class FineTuningJobsRequest(BaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/fine-tuning/create + model: str + """The name of the model to fine-tune.""" + + training_file: str + """The ID of an uploaded file that contains training data.""" + + hyperparameters: Optional[Hyperparameters] = None + """The hyperparameters used for the fine-tuning job.""" + + suffix: Optional[str] = None + """A string of up to 64 characters that will be added to your fine-tuned model name.""" + + validation_file: Optional[str] = None + """The ID of an uploaded file that contains validation data.""" + + integrations: Optional[List[FineTuningJobWandbIntegrationObject]] = None + """A list of integrations to enable for your fine-tuning job.""" + + seed: Optional[str] = None + + +class Error(BaseModel): + code: str + """A machine-readable error code.""" + + message: str + """A human-readable error message.""" + + param: Optional[str] = None + """The parameter that was invalid, usually `training_file` or `validation_file`. + + This field will be null if the failure was not parameter-specific. + """ + + +class FineTuningJob(BaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/fine-tuning/object + id: str + """The object identifier, which can be referenced in the API endpoints.""" + + created_at: int + """The Unix timestamp (in seconds) for when the fine-tuning job was created.""" + + error: Optional[Error] = None + """For fine-tuning jobs that have `failed`, this will contain more information on + the cause of the failure.""" + + fine_tuned_model: Optional[str] = None + """The name of the fine-tuned model that is being created. + + The value will be null if the fine-tuning job is still running. + """ + + finished_at: Optional[int] = None + """The Unix timestamp (in seconds) for when the fine-tuning job was finished. + + The value will be null if the fine-tuning job is still running. + """ + + hyperparameters: Hyperparameters + """The hyperparameters used for the fine-tuning job. + + See the [fine-tuning guide](https://platform.openai.com/docs/guides/fine-tuning) + for more details. 
+ """ + + model: str + """The base model that is being fine-tuned.""" + + object: Literal["fine_tuning.job"] = "fine_tuning.job" + """The object type, which is always "fine_tuning.job".""" + + organization_id: Optional[str] = None + """The organization that owns the fine-tuning job.""" + + result_files: List[str] = None + """The compiled results file ID(s) for the fine-tuning job. + + You can retrieve the results with the + [Files API](https://platform.openai.com/docs/api-reference/files/retrieve-contents). + """ + + status: Literal["validating_files", "queued", "running", "succeeded", "failed", "cancelled"] + """The current status of the fine-tuning job, which can be either + `validating_files`, `queued`, `running`, `succeeded`, `failed`, or `cancelled`.""" + + trained_tokens: Optional[int] = None + """The total number of billable tokens processed by this fine-tuning job. + + The value will be null if the fine-tuning job is still running. + """ + + training_file: str + """The file ID used for training. + + You can retrieve the training data with the + [Files API](https://platform.openai.com/docs/api-reference/files/retrieve-contents). + """ + + validation_file: Optional[str] = None + """The file ID used for validation. + + You can retrieve the validation results with the + [Files API](https://platform.openai.com/docs/api-reference/files/retrieve-contents). + """ + + integrations: Optional[List[FineTuningJobWandbIntegrationObject]] = None + """A list of integrations to enable for this fine-tuning job.""" + + seed: Optional[int] = None + """The seed used for the fine-tuning job.""" + + estimated_finish: Optional[int] = None + """The Unix timestamp (in seconds) for when the fine-tuning job is estimated to + finish. + + The value will be null if the fine-tuning job is not running. + """ + + +class FineTuningJobIDRequest(BaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/fine-tuning/retrieve + # https://platform.openai.com/docs/api-reference/fine-tuning/cancel + fine_tuning_job_id: str + """The ID of the fine-tuning job.""" + + +class FineTuningJobListRequest(BaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/fine-tuning/list + after: Optional[str] = None + """Identifier for the last job from the previous pagination request.""" + + limit: Optional[int] = 20 + """Number of fine-tuning jobs to retrieve.""" + + +class FineTuningJobList(BaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/fine-tuning/list + object: str = "list" + """The object type, which is always "list". + + This indicates that the returned data is a list of fine-tuning jobs. + """ + + data: List[FineTuningJob] + """A list containing FineTuningJob objects.""" + + has_more: bool + """Indicates whether there are more fine-tuning jobs beyond the current list. + + If true, additional requests can be made to retrieve more jobs. 
+ """ diff --git a/comps/finetuning/README.md b/comps/finetuning/README.md index 432121d4d..e56de1e82 100644 --- a/comps/finetuning/README.md +++ b/comps/finetuning/README.md @@ -60,7 +60,7 @@ docker build -t opea/finetuning:latest --build-arg https_proxy=$https_proxy --bu Start docker container with below command: ```bash -docker run -d --name="finetuning-server" -p 8000:8000 --runtime=runc --ipc=host -e http_proxy=$http_proxy -e https_proxy=$https_proxy opea/finetuning:latest +docker run -d --name="finetuning-server" -p 8001:8001 --runtime=runc --ipc=host -e http_proxy=$http_proxy -e https_proxy=$https_proxy opea/finetuning:latest ``` # 🚀3. Consume Finetuning Service @@ -70,11 +70,25 @@ docker run -d --name="finetuning-server" -p 8000:8000 --runtime=runc --ipc=host Assuming a training file `alpaca_data.json` is uploaded, it can be downloaded in [here](https://github.com/tatsu-lab/stanford_alpaca/blob/main/alpaca_data.json), the following script launches a finetuning job using `meta-llama/Llama-2-7b-chat-hf` as base model: ```bash -curl http://${your_ip}:8000/v1/fine_tuning/jobs \ +# create a finetuning job +curl http://${your_ip}:8001/v1/fine_tuning/jobs \ -X POST \ -H "Content-Type: application/json" \ -d '{ "training_file": "alpaca_data.json", "model": "meta-llama/Llama-2-7b-chat-hf" }' + +# list finetuning jobs +curl http://${your_ip}:8001/v1/fine_tuning/jobs -X GET + +# retrieve one finetuning job +curl http://localhost:8001/v1/fine_tuning/jobs/retrieve -X POST -H "Content-Type: application/json" -d '{ + "fine_tuning_job_id": ${fine_tuning_job_id}}' + +# cancel one finetuning job + +curl http://localhost:8001/v1/fine_tuning/jobs/cancel -X POST -H "Content-Type: application/json" -d '{ + "fine_tuning_job_id": ${fine_tuning_job_id}}' + ``` diff --git a/comps/finetuning/finetuning_service.py b/comps/finetuning/finetuning_service.py index a1caba88d..2b0b3a91c 100644 --- a/comps/finetuning/finetuning_service.py +++ b/comps/finetuning/finetuning_service.py @@ -1,41 +1,45 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -import uvicorn -from fastapi import BackgroundTasks, FastAPI +from fastapi import BackgroundTasks +from comps import opea_microservices, register_microservice +from comps.cores.proto.api_protocol import FineTuningJobIDRequest, FineTuningJobsRequest from comps.finetuning.handlers import ( handle_cancel_finetuning_job, handle_create_finetuning_jobs, handle_list_finetuning_jobs, handle_retrieve_finetuning_job, ) -from comps.finetuning.models import FineTuningJob, FineTuningJobList, FineTuningJobsRequest -app = FastAPI() - -@app.post("/v1/fine_tuning/jobs", response_model=FineTuningJob) +@register_microservice(name="opea_service@finetuning", endpoint="/v1/fine_tuning/jobs", host="0.0.0.0", port=8001) def create_finetuning_jobs(request: FineTuningJobsRequest, background_tasks: BackgroundTasks): return handle_create_finetuning_jobs(request, background_tasks) -@app.get("/v1/fine_tuning/jobs", response_model=FineTuningJobList) +@register_microservice( + name="opea_service@finetuning", endpoint="/v1/fine_tuning/jobs", host="0.0.0.0", port=8001, methods=["GET"] +) def list_finetuning_jobs(): return handle_list_finetuning_jobs() -@app.get("/v1/fine_tuning/jobs/{fine_tuning_job_id}", response_model=FineTuningJob) -def retrieve_finetuning_job(fine_tuning_job_id): - job = handle_retrieve_finetuning_job(fine_tuning_job_id) +@register_microservice( + name="opea_service@finetuning", endpoint="/v1/fine_tuning/jobs/retrieve", host="0.0.0.0", port=8001 +) +def 
retrieve_finetuning_job(request: FineTuningJobIDRequest): + job = handle_retrieve_finetuning_job(request) return job -@app.post("/v1/fine_tuning/jobs/{fine_tuning_job_id}/cancel", response_model=FineTuningJob) -def cancel_finetuning_job(fine_tuning_job_id): - job = handle_cancel_finetuning_job(fine_tuning_job_id) +@register_microservice( + name="opea_service@finetuning", endpoint="/v1/fine_tuning/jobs/cancel", host="0.0.0.0", port=8001 +) +def cancel_finetuning_job(request: FineTuningJobIDRequest): + job = handle_cancel_finetuning_job(request) return job if __name__ == "__main__": - uvicorn.run(app, host="0.0.0.0", port=8000) + opea_microservices["opea_service@finetuning"].start() diff --git a/comps/finetuning/handlers.py b/comps/finetuning/handlers.py index eaf241890..5b842dffb 100644 --- a/comps/finetuning/handlers.py +++ b/comps/finetuning/handlers.py @@ -11,8 +11,13 @@ from pydantic_yaml import parse_yaml_raw_as, to_yaml_file from ray.job_submission import JobSubmissionClient +from comps.cores.proto.api_protocol import ( + FineTuningJob, + FineTuningJobIDRequest, + FineTuningJobList, + FineTuningJobsRequest, +) from comps.finetuning.llm_on_ray.finetune.finetune_config import FinetuneConfig -from comps.finetuning.models import FineTuningJob, FineTuningJobList, FineTuningJobsRequest MODEL_CONFIG_FILE_MAP = { "meta-llama/Llama-2-7b-chat-hf": "./models/llama-2-7b-chat-hf.yaml", @@ -20,6 +25,12 @@ } DATASET_BASE_PATH = "datasets" +JOBS_PATH = "jobs" +if not os.path.exists(DATASET_BASE_PATH): + os.mkdir(DATASET_BASE_PATH) + +if not os.path.exists(JOBS_PATH): + os.mkdir(JOBS_PATH) FineTuningJobID = str CHECK_JOB_STATUS_INTERVAL = 5 # Check every 5 secs @@ -61,6 +72,17 @@ def handle_create_finetuning_jobs(request: FineTuningJobsRequest, background_tas finetune_config = parse_yaml_raw_as(FinetuneConfig, f) finetune_config.Dataset.train_file = train_file_path + + if request.hyperparameters is not None: + if request.hyperparameters.epochs != "auto": + finetune_config.Training.epochs = request.hyperparameters.epochs + + if request.hyperparameters.batch_size != "auto": + finetune_config.Training.batch_size = request.hyperparameters.batch_size + + if request.hyperparameters.learning_rate_multiplier != "auto": + finetune_config.Training.learning_rate = request.hyperparameters.learning_rate_multiplier + if os.getenv("HF_TOKEN", None): finetune_config.General.config.use_auth_token = os.getenv("HF_TOKEN", None) @@ -75,11 +97,10 @@ def handle_create_finetuning_jobs(request: FineTuningJobsRequest, background_tas "learning_rate_multiplier": finetune_config.Training.learning_rate, }, status="running", - # TODO: Add seed in finetune config - seed=random.randint(0, 1000), + seed=random.randint(0, 1000) if request.seed is None else request.seed, ) - finetune_config_file = f"jobs/{job.id}.yaml" + finetune_config_file = f"{JOBS_PATH}/{job.id}.yaml" to_yaml_file(finetune_config_file, finetune_config) global ray_client @@ -107,14 +128,18 @@ def handle_list_finetuning_jobs(): return finetuning_jobs_list -def handle_retrieve_finetuning_job(fine_tuning_job_id): +def handle_retrieve_finetuning_job(request: FineTuningJobIDRequest): + fine_tuning_job_id = request.fine_tuning_job_id + job = running_finetuning_jobs.get(fine_tuning_job_id) if job is None: raise HTTPException(status_code=404, detail=f"Fine-tuning job '{fine_tuning_job_id}' not found!") return job -def handle_cancel_finetuning_job(fine_tuning_job_id): +def handle_cancel_finetuning_job(request: FineTuningJobIDRequest): + fine_tuning_job_id = 
request.fine_tuning_job_id + ray_job_id = finetuning_job_to_ray_job.get(fine_tuning_job_id) if ray_job_id is None: raise HTTPException(status_code=404, detail=f"Fine-tuning job '{fine_tuning_job_id}' not found!") @@ -126,15 +151,3 @@ def handle_cancel_finetuning_job(fine_tuning_job_id): job = running_finetuning_jobs.get(fine_tuning_job_id) job.status = "cancelled" return job - - -# def cancel_all_jobs(): -# global ray_client -# ray_client = JobSubmissionClient() if ray_client is None else ray_client -# # stop all jobs -# for job_id in finetuning_job_to_ray_job.values(): -# ray_client.stop_job(job_id) - -# for job_id in running_finetuning_jobs: -# running_finetuning_jobs[job_id].status = "cancelled" -# return running_finetuning_jobs diff --git a/comps/finetuning/models.py b/comps/finetuning/models.py deleted file mode 100644 index f6757364d..000000000 --- a/comps/finetuning/models.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright (C) 2024 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - -from datetime import datetime -from typing import List, Optional - -from pydantic import BaseModel - - -class FineTuningJobsRequest(BaseModel): - training_file: str - model: str - - -class Hyperparameters(BaseModel): - n_epochs: int - batch_size: int - learning_rate_multiplier: float - - -class FineTuningJob(BaseModel): - object: str = "fine_tuning.job" # Set as constant - id: str - model: str - created_at: int - finished_at: int = None - fine_tuned_model: str = None - organization_id: str = None - result_files: List[str] = None - status: str - validation_file: str = None - training_file: str - hyperparameters: Hyperparameters - trained_tokens: int = None - integrations: List[str] = [] # Empty list by default - seed: int - estimated_finish: int = 0 # Set default value to 0 - - -class FineTuningJobList(BaseModel): - object: str = "list" # Set as constant - data: List[FineTuningJob] - has_more: bool - - -class FineTuningJobEvent(BaseModel): - object: str = "fine_tuning.job.event" # Set as constant - id: str - created_at: int - level: str - message: str - data: None = None # No data expected for this event type, set to None - type: str = "message" # Default event type is "message" From 4257281f55fe40d5f380e0fd7abe1ae0cd82725f Mon Sep 17 00:00:00 2001 From: Xinyu Ye Date: Wed, 21 Aug 2024 15:55:13 +0800 Subject: [PATCH 10/13] added upload_training_files and list_checkpoints. 
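Together with the existing job endpoints, these two routes cover the basic data-in / artifacts-out loop. A rough client-side sketch follows; it is not part of this patch and assumes the service is reachable on localhost:8001 with `alpaca_data.json` in the working directory:

```python
# Illustrative client for the upload and checkpoint-listing endpoints.
import requests

BASE = "http://localhost:8001"  # assumed host/port

# 1) Upload a training file; the endpoint expects a multipart field named "files".
with open("alpaca_data.json", "rb") as f:
    resp = requests.post(f"{BASE}/v1/finetune/upload_training_files", files={"files": f})
print(resp.json())  # {"status": 200, "message": "Training files uploaded."}

# 2) Create a fine-tuning job that references the uploaded file.
job = requests.post(
    f"{BASE}/v1/fine_tuning/jobs",
    json={"training_file": "alpaca_data.json", "model": "meta-llama/Llama-2-7b-chat-hf"},
).json()

# 3) List any checkpoints written so far for that job.
ckpts = requests.post(
    f"{BASE}/v1/finetune/list_checkpoints", json={"fine_tuning_job_id": job["id"]}
)
print(ckpts.json())  # {"status": 200, "checkpoints": "..."}
```
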
Signed-off-by: Xinyu Ye
---
 comps/finetuning/README.md                    |  7 +++-
 comps/finetuning/finetuning_service.py        | 31 ++++++++++++++++-
 comps/finetuning/handlers.py                  | 33 +++++++++++++++++--
 .../finetuning/models/llama-2-7b-chat-hf.yaml |  1 -
 comps/finetuning/models/mistral-7b-v0.1.yaml  |  1 -
 comps/finetuning/requirements.txt             |  1 +
 6 files changed, 68 insertions(+), 6 deletions(-)

diff --git a/comps/finetuning/README.md b/comps/finetuning/README.md
index e56de1e82..6ca095bfe 100644
--- a/comps/finetuning/README.md
+++ b/comps/finetuning/README.md
@@ -39,7 +39,6 @@ ray start --address='${head_node_ip}:6379'
 ```bash
 export HF_TOKEN=${your_huggingface_token}
-export RAY_ADDRESS="ray://${ray_head_ip}:10001"
 python finetuning_service.py
 ```
@@ -70,6 +69,9 @@ docker run -d --name="finetuning-server" -p 8001:8001 --runtime=runc --ipc=host
 Assuming a training file `alpaca_data.json` is uploaded, it can be downloaded in [here](https://github.com/tatsu-lab/stanford_alpaca/blob/main/alpaca_data.json), the following script launches a finetuning job using `meta-llama/Llama-2-7b-chat-hf` as base model:

 ```bash
+# upload a training file
+curl http://${your_ip}:8001/v1/finetune/upload_training_files -X POST -H "Content-Type: multipart/form-data" -F "files=@./alpaca_data.json"
+
 # create a finetuning job
 curl http://${your_ip}:8001/v1/fine_tuning/jobs \
   -X POST \
   -H "Content-Type: application/json" \
   -d '{
@@ -91,4 +93,7 @@ curl http://localhost:8001/v1/fine_tuning/jobs/retrieve -X POST -H "Content-
 curl http://localhost:8001/v1/fine_tuning/jobs/cancel -X POST -H "Content-Type: application/json" -d '{
  "fine_tuning_job_id": ${fine_tuning_job_id}}'

+# list checkpoints of a finetuning job
+curl http://${your_ip}:8001/v1/finetune/list_checkpoints -X POST -H "Content-Type: application/json" -d '{"fine_tuning_job_id": ${fine_tuning_job_id}}'
+
 ```
diff --git a/comps/finetuning/finetuning_service.py b/comps/finetuning/finetuning_service.py
index 2b0b3a91c..c446d3a44 100644
--- a/comps/finetuning/finetuning_service.py
+++ b/comps/finetuning/finetuning_service.py
@@ -1,8 +1,11 @@
 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0

-from fastapi import BackgroundTasks
+import os
+import urllib.parse
+from fastapi import BackgroundTasks, File, UploadFile
+from typing import List, Optional, Union
 from comps import opea_microservices, register_microservice
 from comps.cores.proto.api_protocol import FineTuningJobIDRequest, FineTuningJobsRequest
 from comps.finetuning.handlers import (
     handle_cancel_finetuning_job,
     handle_create_finetuning_jobs,
     handle_list_finetuning_jobs,
     handle_retrieve_finetuning_job,
+    save_content_to_local_disk,
+    handle_list_finetuning_checkpoints,
+    DATASET_BASE_PATH,
 )
@@ -41,5 +47,28 @@ def cancel_finetuning_job(request: FineTuningJobIDRequest):
     return job


+@register_microservice(
+    name="opea_service@finetuning", endpoint="/v1/finetune/upload_training_files", host="0.0.0.0", port=8001,
+)
+async def upload_training_files(files: Optional[Union[UploadFile, List[UploadFile]]] = File(None),):
+    if files:
+        if not isinstance(files, list):
+            files = [files]
+        for file in files:
+            filename = urllib.parse.quote(file.filename, safe="")
+            save_path = os.path.join(DATASET_BASE_PATH, filename)
+            await save_content_to_local_disk(save_path, file)
+
+    return {"status": 200, "message": "Training files uploaded."}
+
+
+@register_microservice(
+    name="opea_service@finetuning", endpoint="/v1/finetune/list_checkpoints", host="0.0.0.0", port=8001
+)
+def list_checkpoints(request: FineTuningJobIDRequest):
+    checkpoints = handle_list_finetuning_checkpoints(request)
+    return {"status": 200, "checkpoints":str(checkpoints)}
+
+
 if __name__ == "__main__":
     opea_microservices["opea_service@finetuning"].start()
diff --git a/comps/finetuning/handlers.py b/comps/finetuning/handlers.py
index 5b842dffb..db780e280 100644
--- a/comps/finetuning/handlers.py
+++ b/comps/finetuning/handlers.py
@@ -5,7 +5,8 @@
 import random
 import time
 import uuid
-from typing import Any, Dict, List, Set
+from pathlib import Path
+from typing import Dict

 from fastapi import BackgroundTasks, HTTPException
 from pydantic_yaml import parse_yaml_raw_as, to_yaml_file
@@ -99,7 +100,7 @@ def handle_create_finetuning_jobs(request: FineTuningJobsRequest, background_tas
         status="running",
         seed=random.randint(0, 1000) if request.seed is None else request.seed,
     )
-
+    finetune_config.General.output_dir = os.path.join(JOBS_PATH, job.id)
     finetune_config_file = f"{JOBS_PATH}/{job.id}.yaml"
     to_yaml_file(finetune_config_file, finetune_config)
@@ -151,3 +152,31 @@ def handle_cancel_finetuning_job(request: FineTuningJobIDRequest):
     job = running_finetuning_jobs.get(fine_tuning_job_id)
     job.status = "cancelled"
     return job
+
+
+async def save_content_to_local_disk(save_path: str, content):
+    save_path = Path(save_path)
+    try:
+        if isinstance(content, str):
+            with open(save_path, "w", encoding="utf-8") as file:
+                file.write(content)
+        else:
+            with save_path.open("wb") as fout:
+                content = await content.read()
+                fout.write(content)
+    except Exception as e:
+        print(f"Write file failed. Exception: {e}")
+        raise HTTPException(status_code=500, detail=f"Write file {save_path} failed. Exception: {e}")
+
+
+def handle_list_finetuning_checkpoints(request: FineTuningJobIDRequest):
+    fine_tuning_job_id = request.fine_tuning_job_id
+
+    job = running_finetuning_jobs.get(fine_tuning_job_id)
+    if job is None:
+        raise HTTPException(status_code=404, detail=f"Fine-tuning job '{fine_tuning_job_id}' not found!")
+    output_dir = os.path.join(JOBS_PATH, job.id)
+    checkpoints = []
+    if os.path.exists(output_dir):
+        checkpoints = os.listdir(output_dir)
+    return checkpoints
\ No newline at end of file
diff --git a/comps/finetuning/models/llama-2-7b-chat-hf.yaml b/comps/finetuning/models/llama-2-7b-chat-hf.yaml
index 3918196a2..c09c60943 100644
--- a/comps/finetuning/models/llama-2-7b-chat-hf.yaml
+++ b/comps/finetuning/models/llama-2-7b-chat-hf.yaml
@@ -4,7 +4,6 @@ General:
   base_model: meta-llama/Llama-2-7b-chat-hf
   gpt_base_model: false
-  output_dir: /tmp/llm-ray/output
   save_strategy: no
   config:
     trust_remote_code: false
diff --git a/comps/finetuning/models/mistral-7b-v0.1.yaml b/comps/finetuning/models/mistral-7b-v0.1.yaml
index 29d05de93..81571356f 100644
--- a/comps/finetuning/models/mistral-7b-v0.1.yaml
+++ b/comps/finetuning/models/mistral-7b-v0.1.yaml
@@ -4,7 +4,6 @@ General:
   base_model: mistralai/Mistral-7B-v0.1
   gpt_base_model: false
-  output_dir: /tmp/llm-ray/output
   save_strategy: no
   config:
     trust_remote_code: false
diff --git a/comps/finetuning/requirements.txt b/comps/finetuning/requirements.txt
index c8f2f37d1..ece1ea253 100644
--- a/comps/finetuning/requirements.txt
+++ b/comps/finetuning/requirements.txt
@@ -10,6 +10,7 @@ peft
 prometheus-fastapi-instrumentator
 pydantic
 pydantic_yaml
+python-multipart
 pyyaml
 ray[all]
 requests

From 2b6787fc1a1eb9177766f685a47ab3588730cf5b Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 21 Aug 2024 08:02:18 +0000
Subject: [PATCH 11/13] [pre-commit.ci] auto fixes from pre-commit.com
hooks for more information, see https://pre-commit.ci --- comps/finetuning/finetuning_service.py | 18 ++++++++++++------ comps/finetuning/handlers.py | 2 +- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/comps/finetuning/finetuning_service.py b/comps/finetuning/finetuning_service.py index c446d3a44..fa11060af 100644 --- a/comps/finetuning/finetuning_service.py +++ b/comps/finetuning/finetuning_service.py @@ -3,19 +3,20 @@ import os import urllib.parse +from typing import List, Optional, Union from fastapi import BackgroundTasks, File, UploadFile -from typing import List, Optional, Union + from comps import opea_microservices, register_microservice from comps.cores.proto.api_protocol import FineTuningJobIDRequest, FineTuningJobsRequest from comps.finetuning.handlers import ( + DATASET_BASE_PATH, handle_cancel_finetuning_job, handle_create_finetuning_jobs, + handle_list_finetuning_checkpoints, handle_list_finetuning_jobs, handle_retrieve_finetuning_job, save_content_to_local_disk, - handle_list_finetuning_checkpoints, - DATASET_BASE_PATH, ) @@ -48,9 +49,14 @@ def cancel_finetuning_job(request: FineTuningJobIDRequest): @register_microservice( - name="opea_service@finetuning", endpoint="/v1/finetune/upload_training_files", host="0.0.0.0", port=8001, + name="opea_service@finetuning", + endpoint="/v1/finetune/upload_training_files", + host="0.0.0.0", + port=8001, ) -async def upload_training_files(files: Optional[Union[UploadFile, List[UploadFile]]] = File(None),): +async def upload_training_files( + files: Optional[Union[UploadFile, List[UploadFile]]] = File(None), +): if files: if not isinstance(files, list): files = [files] @@ -67,7 +73,7 @@ async def upload_training_files(files: Optional[Union[UploadFile, List[UploadFil ) def list_checkpoints(request: FineTuningJobIDRequest): checkpoints = handle_list_finetuning_checkpoints(request) - return {"status": 200, "checkpoints":str(checkpoints)} + return {"status": 200, "checkpoints": str(checkpoints)} if __name__ == "__main__": diff --git a/comps/finetuning/handlers.py b/comps/finetuning/handlers.py index db780e280..6a4317499 100644 --- a/comps/finetuning/handlers.py +++ b/comps/finetuning/handlers.py @@ -179,4 +179,4 @@ def handle_list_finetuning_checkpoints(request: FineTuningJobIDRequest): checkpoints = [] if os.path.exists(output_dir): checkpoints = os.listdir(output_dir) - return checkpoints \ No newline at end of file + return checkpoints From bd991e4879eaa9cde15fe3835e573107b520a544 Mon Sep 17 00:00:00 2001 From: lkk <33276950+lkk12014402@users.noreply.github.com> Date: Wed, 21 Aug 2024 23:23:48 +0800 Subject: [PATCH 12/13] Add finetuning gaudi (#544) * enable finetuning on Gaudi * add license Intel. 
* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: test
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 comps/finetuning/README.md                    | 40 ++++++++++++++-----
 comps/finetuning/docker/Dockerfile_hpu        | 31 ++++++++++++++
 comps/finetuning/finetuning_service.py        | 12 +++---
 comps/finetuning/handlers.py                  |  4 ++
 comps/finetuning/lanuch.sh                    | 12 ++++++
 .../finetuning/llm_on_ray/common/__init__.py  |  2 +
 comps/finetuning/llm_on_ray/common/common.py  |  2 +
 .../llm_on_ray/common/torch_config.py         |  2 +
 .../llm_on_ray/finetune/__init__.py           |  2 +
 .../llm_on_ray/finetune/data_process.py       |  2 +
 .../llm_on_ray/finetune/finetune.py           |  2 +
 .../llm_on_ray/finetune/finetune_config.py    |  2 +
 .../finetuning/models/llama-2-7b-chat-hf.yaml |  1 +
 comps/finetuning/models/mistral-7b-v0.1.yaml  |  1 +
 comps/finetuning/requirements.txt             |  2 +-
 15 files changed, 101 insertions(+), 16 deletions(-)
 create mode 100644 comps/finetuning/docker/Dockerfile_hpu
 create mode 100644 comps/finetuning/lanuch.sh

diff --git a/comps/finetuning/README.md b/comps/finetuning/README.md
index 6ca095bfe..14ae91180 100644
--- a/comps/finetuning/README.md
+++ b/comps/finetuning/README.md
@@ -44,7 +44,9 @@ python finetuning_service.py

 # 🚀2. Start Microservice with Docker (Optional 2)

-## 2.1 Build Docker Image
+## 2.1 Setup on CPU
+
+### 2.1.1 Build Docker Image

 Build docker image with below command:

@@ -54,12 +56,32 @@ cd ../../
 docker build -t opea/finetuning:latest --build-arg https_proxy=$https_proxy --build-arg http_proxy=$http_proxy --build-arg HF_TOKEN=$HF_TOKEN -f comps/finetuning/docker/Dockerfile_cpu .
 ```

-## 2.2 Run Docker with CLI
+### 2.1.2 Run Docker with CLI
+
+Start docker container with below command:
+
+```bash
+docker run -d --name="finetuning-server" -p 8005:8005 --runtime=runc --ipc=host -e http_proxy=$http_proxy -e https_proxy=$https_proxy opea/finetuning:latest
+```
+
+## 2.2 Setup on Gaudi2
+
+### 2.2.1 Build Docker Image
+
+Build docker image with below command:
+
+```bash
+cd ../../
+docker build -t opea/finetuning-gaudi:latest --build-arg https_proxy=$https_proxy --build-arg http_proxy=$http_proxy -f comps/finetuning/docker/Dockerfile_hpu .
+```
+
+### 2.2.2 Run Docker with CLI

 Start docker container with below command:

 ```bash
-docker run -d --name="finetuning-server" -p 8001:8001 --runtime=runc --ipc=host -e http_proxy=$http_proxy -e https_proxy=$https_proxy opea/finetuning:latest
+export HF_TOKEN=${your_huggingface_token}
+docker run --runtime=habana -e HABANA_VISIBLE_DEVICES=all -p 8005:8005 -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host -e https_proxy=$https_proxy -e http_proxy=$http_proxy -e no_proxy=$no_proxy -e HF_TOKEN="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" opea/finetuning-gaudi:latest
 ```

 # 🚀3. Consume Finetuning Service

@@ -70,10 +92,10 @@ Assuming a training file `alpaca_data.json` is uploaded, it can be downloaded in
 ```bash
 # upload a training file
-curl http://${your_ip}:8001/v1/finetune/upload_training_files -X POST -H "Content-Type: multipart/form-data" -F "files=@./alpaca_data.json"
+curl http://${your_ip}:8005/v1/finetune/upload_training_files -X POST -H "Content-Type: multipart/form-data" -F "files=@./alpaca_data.json"

 # create a finetuning job
-curl http://${your_ip}:8001/v1/fine_tuning/jobs \
+curl http://${your_ip}:8005/v1/fine_tuning/jobs \
   -X POST \
   -H "Content-Type: application/json" \
   -d '{
@@ -82,18 +104,18 @@ curl http://${your_ip}:8001/v1/fine_tuning/jobs \
   }'

 # list finetuning jobs
-curl http://${your_ip}:8001/v1/fine_tuning/jobs -X GET
+curl http://${your_ip}:8005/v1/fine_tuning/jobs -X GET

 # retrieve one finetuning job
-curl http://localhost:8001/v1/fine_tuning/jobs/retrieve -X POST -H "Content-Type: application/json" -d '{
+curl http://localhost:8005/v1/fine_tuning/jobs/retrieve -X POST -H "Content-Type: application/json" -d '{
  "fine_tuning_job_id": ${fine_tuning_job_id}}'

 # cancel one finetuning job
-curl http://localhost:8001/v1/fine_tuning/jobs/cancel -X POST -H "Content-Type: application/json" -d '{
+curl http://localhost:8005/v1/fine_tuning/jobs/cancel -X POST -H "Content-Type: application/json" -d '{
  "fine_tuning_job_id": ${fine_tuning_job_id}}'

 # list checkpoints of a finetuning job
-curl http://${your_ip}:8001/v1/finetune/list_checkpoints -X POST -H "Content-Type: application/json" -d '{"fine_tuning_job_id": ${fine_tuning_job_id}}'
+curl http://${your_ip}:8005/v1/finetune/list_checkpoints -X POST -H "Content-Type: application/json" -d '{"fine_tuning_job_id": ${fine_tuning_job_id}}'

 ```
diff --git a/comps/finetuning/docker/Dockerfile_hpu b/comps/finetuning/docker/Dockerfile_hpu
new file mode 100644
index 000000000..1277d76c1
--- /dev/null
+++ b/comps/finetuning/docker/Dockerfile_hpu
@@ -0,0 +1,31 @@
+# Use the same python version with ray
+FROM vault.habana.ai/gaudi-docker/1.16.1/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest
+
+ENV DEVICE="hpu"
+
+RUN apt-get update -y && apt-get install -y vim htop net-tools dnsutils
+
+RUN useradd -m -s /bin/bash user && \
+    mkdir -p /home/user && \
+    chown -R user /home/user/
+
+COPY comps /home/user/comps
+
+RUN chown -R user /home/user/comps/finetuning
+
+USER user
+
+ENV PATH=$PATH:/home/user/.local/bin
+
+RUN pip install --no-cache-dir --upgrade pip && \
+    pip install --no-cache-dir -r /home/user/comps/finetuning/requirements.txt && \
+    pip install --no-cache-dir optimum-habana
+
+ENV PYTHONPATH=$PYTHONPATH:/home/user
+
+WORKDIR /home/user/comps/finetuning
+
+ENTRYPOINT ["/bin/bash", "launch.sh"]
+
+# CMD ["/bin/bash"]
+
diff --git a/comps/finetuning/finetuning_service.py b/comps/finetuning/finetuning_service.py
index fa11060af..fabb32bc4 100644
--- a/comps/finetuning/finetuning_service.py
+++ b/comps/finetuning/finetuning_service.py
@@ -20,20 +20,20 @@ )

-@register_microservice(name="opea_service@finetuning", endpoint="/v1/fine_tuning/jobs", host="0.0.0.0", port=8001)
+@register_microservice(name="opea_service@finetuning", endpoint="/v1/fine_tuning/jobs", host="0.0.0.0", port=8005)
 def create_finetuning_jobs(request: FineTuningJobsRequest, background_tasks: BackgroundTasks):
     return handle_create_finetuning_jobs(request, background_tasks)


 @register_microservice(
-    name="opea_service@finetuning", endpoint="/v1/fine_tuning/jobs", host="0.0.0.0", port=8001, methods=["GET"]
+
name="opea_service@finetuning", endpoint="/v1/fine_tuning/jobs", host="0.0.0.0", port=8005, methods=["GET"] ) def list_finetuning_jobs(): return handle_list_finetuning_jobs() @register_microservice( - name="opea_service@finetuning", endpoint="/v1/fine_tuning/jobs/retrieve", host="0.0.0.0", port=8001 + name="opea_service@finetuning", endpoint="/v1/fine_tuning/jobs/retrieve", host="0.0.0.0", port=8005 ) def retrieve_finetuning_job(request: FineTuningJobIDRequest): job = handle_retrieve_finetuning_job(request) @@ -41,7 +41,7 @@ def retrieve_finetuning_job(request: FineTuningJobIDRequest): @register_microservice( - name="opea_service@finetuning", endpoint="/v1/fine_tuning/jobs/cancel", host="0.0.0.0", port=8001 + name="opea_service@finetuning", endpoint="/v1/fine_tuning/jobs/cancel", host="0.0.0.0", port=8005 ) def cancel_finetuning_job(request: FineTuningJobIDRequest): job = handle_cancel_finetuning_job(request) @@ -52,7 +52,7 @@ def cancel_finetuning_job(request: FineTuningJobIDRequest): name="opea_service@finetuning", endpoint="/v1/finetune/upload_training_files", host="0.0.0.0", - port=8001, + port=8005, ) async def upload_training_files( files: Optional[Union[UploadFile, List[UploadFile]]] = File(None), @@ -69,7 +69,7 @@ async def upload_training_files( @register_microservice( - name="opea_service@finetuning", endpoint="/v1/finetune/list_checkpoints", host="0.0.0.0", port=8001 + name="opea_service@finetuning", endpoint="/v1/finetune/list_checkpoints", host="0.0.0.0", port=8005 ) def list_checkpoints(request: FineTuningJobIDRequest): checkpoints = handle_list_finetuning_checkpoints(request) diff --git a/comps/finetuning/handlers.py b/comps/finetuning/handlers.py index 6a4317499..2bdab42a9 100644 --- a/comps/finetuning/handlers.py +++ b/comps/finetuning/handlers.py @@ -101,6 +101,10 @@ def handle_create_finetuning_jobs(request: FineTuningJobsRequest, background_tas seed=random.randint(0, 1000) if request.seed is None else request.seed, ) finetune_config.General.output_dir = os.path.join(JOBS_PATH, job.id) + if os.getenv("DEVICE", ""): + print(f"specific device: {os.getenv('DEVICE')}") + finetune_config.Training.device = os.getenv("DEVICE") + finetune_config_file = f"{JOBS_PATH}/{job.id}.yaml" to_yaml_file(finetune_config_file, finetune_config) diff --git a/comps/finetuning/lanuch.sh b/comps/finetuning/lanuch.sh new file mode 100644 index 000000000..a7e249b6f --- /dev/null +++ b/comps/finetuning/lanuch.sh @@ -0,0 +1,12 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +if [[ -n "$RAY_PORT" ]];then + export RAY_ADDRESS=http://127.0.0.1:$RAY_PORT + ray start --head --port $RAY_PORT +else + export RAY_ADDRESS=http://127.0.0.1:8265 + ray start --head +fi + +python finetuning_service.py diff --git a/comps/finetuning/llm_on_ray/common/__init__.py b/comps/finetuning/llm_on_ray/common/__init__.py index 97380de45..0bd6a7f0e 100644 --- a/comps/finetuning/llm_on_ray/common/__init__.py +++ b/comps/finetuning/llm_on_ray/common/__init__.py @@ -1,3 +1,5 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 # # Copyright 2023 The LLM-on-Ray Authors. # diff --git a/comps/finetuning/llm_on_ray/common/common.py b/comps/finetuning/llm_on_ray/common/common.py index ccc9e2565..e75e2e854 100644 --- a/comps/finetuning/llm_on_ray/common/common.py +++ b/comps/finetuning/llm_on_ray/common/common.py @@ -1,3 +1,5 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 # # Copyright 2023 The LLM-on-Ray Authors. 
# diff --git a/comps/finetuning/llm_on_ray/common/torch_config.py b/comps/finetuning/llm_on_ray/common/torch_config.py index 522bf58ad..744bbbc0c 100644 --- a/comps/finetuning/llm_on_ray/common/torch_config.py +++ b/comps/finetuning/llm_on_ray/common/torch_config.py @@ -1,3 +1,5 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 # # Copyright 2023 The LLM-on-Ray Authors. # diff --git a/comps/finetuning/llm_on_ray/finetune/__init__.py b/comps/finetuning/llm_on_ray/finetune/__init__.py index 854e39ad4..42cdaf68b 100644 --- a/comps/finetuning/llm_on_ray/finetune/__init__.py +++ b/comps/finetuning/llm_on_ray/finetune/__init__.py @@ -1,3 +1,5 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 # # Copyright 2023 The LLM-on-Ray Authors. # diff --git a/comps/finetuning/llm_on_ray/finetune/data_process.py b/comps/finetuning/llm_on_ray/finetune/data_process.py index 66d90bada..30b695da9 100644 --- a/comps/finetuning/llm_on_ray/finetune/data_process.py +++ b/comps/finetuning/llm_on_ray/finetune/data_process.py @@ -1,3 +1,5 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 # # Copyright 2023 The LLM-on-Ray Authors. # diff --git a/comps/finetuning/llm_on_ray/finetune/finetune.py b/comps/finetuning/llm_on_ray/finetune/finetune.py index 57fe2e114..a1fff0004 100644 --- a/comps/finetuning/llm_on_ray/finetune/finetune.py +++ b/comps/finetuning/llm_on_ray/finetune/finetune.py @@ -1,3 +1,5 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 # # Copyright 2023 The LLM-on-Ray Authors. # diff --git a/comps/finetuning/llm_on_ray/finetune/finetune_config.py b/comps/finetuning/llm_on_ray/finetune/finetune_config.py index ba2a7671e..1a191d360 100644 --- a/comps/finetuning/llm_on_ray/finetune/finetune_config.py +++ b/comps/finetuning/llm_on_ray/finetune/finetune_config.py @@ -1,3 +1,5 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 # # Copyright 2023 The LLM-on-Ray Authors. # diff --git a/comps/finetuning/models/llama-2-7b-chat-hf.yaml b/comps/finetuning/models/llama-2-7b-chat-hf.yaml index c09c60943..d6ae5f34d 100644 --- a/comps/finetuning/models/llama-2-7b-chat-hf.yaml +++ b/comps/finetuning/models/llama-2-7b-chat-hf.yaml @@ -3,6 +3,7 @@ General: base_model: meta-llama/Llama-2-7b-chat-hf + output_dir: "./tmp" gpt_base_model: false save_strategy: no config: diff --git a/comps/finetuning/models/mistral-7b-v0.1.yaml b/comps/finetuning/models/mistral-7b-v0.1.yaml index 81571356f..4334fa37e 100644 --- a/comps/finetuning/models/mistral-7b-v0.1.yaml +++ b/comps/finetuning/models/mistral-7b-v0.1.yaml @@ -3,6 +3,7 @@ General: base_model: mistralai/Mistral-7B-v0.1 + output_dir: "./tmp" gpt_base_model: false save_strategy: no config: diff --git a/comps/finetuning/requirements.txt b/comps/finetuning/requirements.txt index ece1ea253..4255a3716 100644 --- a/comps/finetuning/requirements.txt +++ b/comps/finetuning/requirements.txt @@ -8,7 +8,7 @@ opentelemetry-exporter-otlp opentelemetry-sdk peft prometheus-fastapi-instrumentator -pydantic +pydantic==2.8.2 pydantic_yaml python-multipart pyyaml From a73c0c4bc567e5d1eae70e35c0421e381a72df4e Mon Sep 17 00:00:00 2001 From: lkk <33276950+lkk12014402@users.noreply.github.com> Date: Wed, 21 Aug 2024 23:33:15 +0800 Subject: [PATCH 13/13] fix license and typo (#545) * remove token. * fix license. 
--------- Co-authored-by: root --- comps/finetuning/README.md | 2 +- comps/finetuning/llm_on_ray/common/__init__.py | 13 ------------- comps/finetuning/llm_on_ray/common/common.py | 13 ------------- comps/finetuning/llm_on_ray/common/logging.py | 15 ++------------- .../finetuning/llm_on_ray/common/torch_config.py | 13 ------------- comps/finetuning/llm_on_ray/finetune/__init__.py | 13 ------------- .../llm_on_ray/finetune/data_process.py | 13 ------------- comps/finetuning/llm_on_ray/finetune/finetune.py | 13 ------------- .../llm_on_ray/finetune/finetune_config.py | 13 ------------- 9 files changed, 3 insertions(+), 105 deletions(-) diff --git a/comps/finetuning/README.md b/comps/finetuning/README.md index 14ae91180..411395ec9 100644 --- a/comps/finetuning/README.md +++ b/comps/finetuning/README.md @@ -81,7 +81,7 @@ Start docker container with below command: ```bash export HF_TOKEN=${your_huggingface_token} -docker run --runtime=habana -e HABANA_VISIBLE_DEVICES=all -p 8005:8005 -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host -e https_proxy=$https_proxy -e http_proxy=$http_proxy -e no_proxy=$no_proxy -e HF_TOKEN="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" opea/finetuning-gaudi:latest +docker run --runtime=habana -e HABANA_VISIBLE_DEVICES=all -p 8005:8005 -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host -e https_proxy=$https_proxy -e http_proxy=$http_proxy -e no_proxy=$no_proxy -e HF_TOKEN=$HF_TOKEN opea/finetuning-gaudi:latest ``` # 🚀3. Consume Finetuning Service diff --git a/comps/finetuning/llm_on_ray/common/__init__.py b/comps/finetuning/llm_on_ray/common/__init__.py index 0bd6a7f0e..a4ad1e878 100644 --- a/comps/finetuning/llm_on_ray/common/__init__.py +++ b/comps/finetuning/llm_on_ray/common/__init__.py @@ -2,19 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 # # Copyright 2023 The LLM-on-Ray Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# from .logging import logger from .torch_config import TorchConfig diff --git a/comps/finetuning/llm_on_ray/common/common.py b/comps/finetuning/llm_on_ray/common/common.py index e75e2e854..136d2526f 100644 --- a/comps/finetuning/llm_on_ray/common/common.py +++ b/comps/finetuning/llm_on_ray/common/common.py @@ -2,19 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 # # Copyright 2023 The LLM-on-Ray Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# import glob import importlib diff --git a/comps/finetuning/llm_on_ray/common/logging.py b/comps/finetuning/llm_on_ray/common/logging.py index 6d3f6ae80..e2aec567a 100644 --- a/comps/finetuning/llm_on_ray/common/logging.py +++ b/comps/finetuning/llm_on_ray/common/logging.py @@ -1,18 +1,7 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 # # Copyright 2023 The LLM-on-Ray Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# import functools import logging diff --git a/comps/finetuning/llm_on_ray/common/torch_config.py b/comps/finetuning/llm_on_ray/common/torch_config.py index 744bbbc0c..9e3f48a7c 100644 --- a/comps/finetuning/llm_on_ray/common/torch_config.py +++ b/comps/finetuning/llm_on_ray/common/torch_config.py @@ -2,19 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 # # Copyright 2023 The LLM-on-Ray Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# import os import sys diff --git a/comps/finetuning/llm_on_ray/finetune/__init__.py b/comps/finetuning/llm_on_ray/finetune/__init__.py index 42cdaf68b..0262e494a 100644 --- a/comps/finetuning/llm_on_ray/finetune/__init__.py +++ b/comps/finetuning/llm_on_ray/finetune/__init__.py @@ -2,16 +2,3 @@ # SPDX-License-Identifier: Apache-2.0 # # Copyright 2023 The LLM-on-Ray Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# diff --git a/comps/finetuning/llm_on_ray/finetune/data_process.py b/comps/finetuning/llm_on_ray/finetune/data_process.py index 30b695da9..ab5efcc09 100644 --- a/comps/finetuning/llm_on_ray/finetune/data_process.py +++ b/comps/finetuning/llm_on_ray/finetune/data_process.py @@ -2,19 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 # # Copyright 2023 The LLM-on-Ray Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# import copy import re diff --git a/comps/finetuning/llm_on_ray/finetune/finetune.py b/comps/finetuning/llm_on_ray/finetune/finetune.py index a1fff0004..f268800f2 100644 --- a/comps/finetuning/llm_on_ray/finetune/finetune.py +++ b/comps/finetuning/llm_on_ray/finetune/finetune.py @@ -2,19 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 # # Copyright 2023 The LLM-on-Ray Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# #!/usr/bin/env python diff --git a/comps/finetuning/llm_on_ray/finetune/finetune_config.py b/comps/finetuning/llm_on_ray/finetune/finetune_config.py index 1a191d360..391c6e6c8 100644 --- a/comps/finetuning/llm_on_ray/finetune/finetune_config.py +++ b/comps/finetuning/llm_on_ray/finetune/finetune_config.py @@ -2,19 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 # # Copyright 2023 The LLM-on-Ray Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# from typing import List, Optional, Union
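The job-creation handler patched above also reads an optional `seed` and an optional `hyperparameters` object (epochs, batch size, learning-rate multiplier, each defaulting to "auto"). A minimal sketch of exercising those fields against the final port 8005, assuming the `FineTuningJobsRequest` schema in `comps.cores.proto.api_protocol` uses the same field names the handler accesses:

```bash
# create a finetuning job with explicit hyperparameters and a fixed seed
# (field names mirror the attributes read in handlers.py: epochs, batch_size,
#  learning_rate_multiplier; adjust if the protocol schema differs)
curl http://${your_ip}:8005/v1/fine_tuning/jobs \
  -X POST \
  -H "Content-Type: application/json" \
  -d '{
    "training_file": "alpaca_data.json",
    "model": "meta-llama/Llama-2-7b-chat-hf",
    "seed": 42,
    "hyperparameters": {
      "epochs": 3,
      "batch_size": 8,
      "learning_rate_multiplier": 0.5
    }
  }'
```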