Created
October 6, 2024 13:48
-
-
Save almugabo/12d0f4ae745ce7da5c9f5a19073f3976 to your computer and use it in GitHub Desktop.
Torchtune wrappers
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| ## This gist includes some scripts to work with Torchtune. | |
| from huggingface_hub import snapshot_download | |
| from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError | |
| from pathlib import Path | |
def download_model(
    model_id: str,
    output_dir: "str | None" = None,
    hf_token: "str | None" = None,
    ignore_patterns: "list[str] | tuple[str, ...] | str | None" = ("*.safetensors",),
) -> str:
    """
    Download a model repository snapshot from the Hugging Face Hub.

    Args:
        model_id (str): The repository ID on Hugging Face (e.g., 'meta-llama/Llama-2-7b-hf').
        output_dir (str, optional): Directory to save the model files. Defaults to
            `/tmp/<model_name>`, where `<model_name>` is the last `/`-separated
            component of `model_id`.
        hf_token (str, optional): Hugging Face token for gated models.
        ignore_patterns (list, optional): File patterns to skip when downloading.
            Defaults to ['*.safetensors']. Pass None or an empty list to skip nothing.

    Returns:
        str: The local path of the downloaded snapshot, as returned by
        `snapshot_download`.

    Raises:
        RuntimeError: If the repository is gated and inaccessible, does not exist,
            or the download fails for any other reason. The original error is
            attached as `__cause__`.
    """
    # Default output_dir is `/tmp/<model_name>` if not specified
    if output_dir is None:
        model_name = model_id.split("/")[-1]
        output_dir = Path("/tmp") / model_name
    else:
        output_dir = Path(output_dir)

    # The default is an immutable tuple to avoid the shared-mutable-default
    # pitfall; turn it back into a list so the printed message and the value
    # forwarded to snapshot_download match the old `["*.safetensors"]` default.
    if isinstance(ignore_patterns, tuple):
        ignore_patterns = list(ignore_patterns)

    # Inform the user of ignored patterns
    if ignore_patterns:
        print(f"Ignoring files matching the following patterns: {ignore_patterns}")

    try:
        # `local_dir_use_symlinks` is intentionally not passed: "auto" was its
        # default in older huggingface_hub releases, and newer releases ignore
        # the argument and only emit a deprecation warning.
        true_output_dir = snapshot_download(
            model_id,
            local_dir=output_dir,
            ignore_patterns=ignore_patterns,
            token=hf_token,
        )
    # GatedRepoError is a subclass of RepositoryNotFoundError, so it must be
    # caught first to get the more specific message.
    except GatedRepoError as err:
        raise RuntimeError(
            "It looks like you are trying to access a gated repository. Please ensure you "
            "have access to the repository and provide the proper Hugging Face API token."
        ) from err
    except RepositoryNotFoundError as err:
        raise RuntimeError(
            f"Repository '{model_id}' not found on the Hugging Face Hub."
        ) from err
    except Exception as err:
        raise RuntimeError(f"Failed to download {model_id} with error: {err}") from err

    print(
        f"Successfully downloaded model repo and wrote to the following location: {true_output_dir}"
    )
    return true_output_dir
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment