Skip to content

Instantly share code, notes, and snippets.

@almugabo
Created October 6, 2024 13:48
Show Gist options
  • Select an option

  • Save almugabo/12d0f4ae745ce7da5c9f5a19073f3976 to your computer and use it in GitHub Desktop.

Select an option

Save almugabo/12d0f4ae745ce7da5c9f5a19073f3976 to your computer and use it in GitHub Desktop.
Torchtune wrappers
## This gist includes some scripts for working with Torchtune
from pathlib import Path
from typing import Optional, Sequence

from huggingface_hub import snapshot_download
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
def download_model(
    model_id: str,
    output_dir: Optional[str] = None,
    hf_token: Optional[str] = None,
    ignore_patterns: Optional[Sequence[str]] = ("*.safetensors",),
) -> None:
    """
    Downloads a model from the Hugging Face Hub.

    Args:
        model_id (str): The repository ID on Hugging Face (e.g., 'meta-llama/Llama-2-7b-hf').
        output_dir (str, optional): Directory to save the model files. Defaults to
            `/tmp/<model_name>`, where `<model_name>` is the part of `model_id` after the last "/".
        hf_token (str, optional): Hugging Face token for gated models.
        ignore_patterns (sequence of str, optional): File patterns to ignore. Defaults to
            ['*.safetensors']. Pass `None` or an empty list to skip no files.

    Returns:
        None. The location reported by `snapshot_download` is printed on success.

    Raises:
        Exception: If the repository is gated, cannot be found, or the download fails for
            any other reason. The underlying error is attached as `__cause__`.
    """
    # Default output_dir is `/tmp/<model_name>` if not specified
    if output_dir is None:
        target_dir = Path("/tmp") / model_id.split("/")[-1]
    else:
        target_dir = Path(output_dir)

    # The signature default is an immutable tuple so no state is shared between calls.
    # Copy any non-string sequence into a fresh list; a lone string pattern and None
    # are forwarded untouched.
    if ignore_patterns is None or isinstance(ignore_patterns, str):
        patterns = ignore_patterns
    else:
        patterns = list(ignore_patterns)

    # Inform the user of ignored patterns
    if patterns:
        print(f"Ignoring files matching the following patterns: {patterns}")

    try:
        # Use Hugging Face's snapshot_download to pull files.
        # `local_dir_use_symlinks` is intentionally not passed: "auto" is the library
        # default and the parameter is deprecated in recent huggingface_hub releases.
        true_output_dir = snapshot_download(
            model_id,
            local_dir=target_dir,
            ignore_patterns=patterns,
            token=hf_token,
        )
    except GatedRepoError as err:
        # Handled before the more general handlers below so gated repos get a
        # dedicated message.
        raise Exception(
            "It looks like you are trying to access a gated repository. Please ensure you "
            "have access to the repository and provide the proper Hugging Face API token."
        ) from err
    except RepositoryNotFoundError as err:
        raise Exception(
            f"Repository '{model_id}' not found on the Hugging Face Hub."
        ) from err
    except Exception as err:
        # Top-level boundary: wrap anything else with the model id for context,
        # keeping the original error chained as the cause.
        raise Exception(f"Failed to download {model_id} with error: {err}") from err

    print(
        f"Successfully downloaded model repo and wrote to the following location: {true_output_dir}"
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment