Skip to content

Instantly share code, notes, and snippets.

@anapaulagomes
Last active August 25, 2025 15:39
Show Gist options
  • Select an option

  • Save anapaulagomes/a48c4c8844e0940c1de7a52c828ca670 to your computer and use it in GitHub Desktop.

Select an option

Save anapaulagomes/a48c4c8844e0940c1de7a52c828ca670 to your computer and use it in GitHub Desktop.
import logging
import os
from pathlib import Path
from typing import Optional, Union

from huggingface_hub import snapshot_download
"""
Prerequisite: install the Hugging Face Hub CLI https://huggingface.co/docs/huggingface_hub/v0.19.1/en/guides/cli
Usage: hf download <wanted model>
Env vars:
- HF_HOME: root of the Hugging Face cache. Defaults to ~/.cache/huggingface; model weights are stored under $HF_HOME/hub.
- HF_FROM_LIMITED_ENV: set this to True when running your code on HPC or another limited environment,
  so that models are resolved to their locally cached weights.
  If unset, `get_or_download_model` returns the model name unchanged and loading runs as usual.
In your code, import `get_or_download_model` to load your local weights.
Usage:
model_name = "<wanted model>"
tokenizer = AutoTokenizer.from_pretrained(get_or_download_model(model_name))
model = AutoModel.from_pretrained(get_or_download_model(model_name))
"""
logger = logging.getLogger(__name__)

# Root of the Hugging Face cache. ``expanduser`` is required: without it the
# default "~/.cache/huggingface" is a *relative* path to a folder literally
# named "~" in the current working directory.
HF_DEFAULT_HOME = Path(os.getenv("HF_HOME", "~/.cache/huggingface")).expanduser()

# Values of HF_FROM_LIMITED_ENV (compared case-insensitively) that enable
# local-cache resolution; anything else, including unset, means "regular".
_TRUTHY_ENV_VALUES = frozenset({"1", "true", "yes", "on"})


def get_or_download_model(
    repo_id: str,
    revision: str = "main",
    cache_dir: Optional[Union[str, Path]] = None,
) -> Union[str, Path]:
    """Resolve a Hub model reference to something ``from_pretrained`` can load.

    In a regular environment (``HF_FROM_LIMITED_ENV`` unset or not one of
    ``1``/``true``/``yes``/``on``) the ``repo_id`` is returned unchanged so the
    caller loads the model by name as usual. In a limited environment the
    local snapshot directory for ``revision`` is returned; the model is
    downloaded with ``snapshot_download`` only when that snapshot is not
    already present in the cache.

    Args:
        repo_id: Model reference ``org_name/model_name``, such as
            ``'meta-llama/Llama-2-7b-chat-hf'``.
        revision: Branch, tag, or commit hash. Defaults to ``'main'``.
        cache_dir: Hub cache directory to look in and download into.
            Defaults to ``HF_DEFAULT_HOME / "hub"``.

    Returns:
        ``repo_id`` (a ``str``) in a regular environment, otherwise the
        ``Path`` of the snapshot directory containing the model files.
    """
    flag = os.getenv("HF_FROM_LIMITED_ENV", "").strip().lower()
    if flag not in _TRUTHY_ENV_VALUES:
        logger.info("Regular environment detected.")
        return repo_id

    if cache_dir is None:
        models_dir = HF_DEFAULT_HOME / "hub"
    else:
        models_dir = Path(cache_dir).expanduser()

    # Cache layout: <models_dir>/models--<org>--<name>/{refs,snapshots}/...
    cached_dir = models_dir / f'models--{repo_id.replace("/", "--")}'
    snapshots_dir = cached_dir / "snapshots"
    ref_file = cached_dir / "refs" / revision

    if ref_file.is_file():
        # Branch/tag names map to a commit hash stored in refs/<revision>;
        # strip so a trailing newline does not corrupt the snapshot path.
        snapshot = snapshots_dir / ref_file.read_text().strip()
        if snapshot.is_dir():
            return snapshot
    elif (snapshots_dir / revision).is_dir():
        # Presumably a commit hash passed directly: it has no refs entry and
        # its snapshot folder is named after the hash itself.
        return snapshots_dir / revision

    # Revision not cached (the repo folder may still exist for another
    # revision): download it and use the snapshot path the library returns.
    return Path(
        snapshot_download(repo_id=repo_id, revision=revision, cache_dir=models_dir)
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment