Last active
August 25, 2025 15:39
-
-
Save anapaulagomes/a48c4c8844e0940c1de7a52c828ca670 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging
import os
from pathlib import Path
from typing import Optional, Union

from huggingface_hub import snapshot_download
"""
Prerequisite: install the Hugging Face Hub CLI, see https://huggingface.co/docs/huggingface_hub/v0.19.1/en/guides/cli
Usage: hf download <wanted model>
Env vars:
- HF_HOME: leave it unset to use the default directory, or set it to the path that holds the Hugging Face weights.
- HF_FROM_LIMITED_ENV: set this to True when running your code from an HPC cluster or another limited environment.
  If false, the model name is returned and everything runs as usual.
In your code, you can import `get_or_download_model` to load your local weights.
Usage:
    model_name = "<wanted model>"
    tokenizer = AutoTokenizer.from_pretrained(get_or_download_model(model_name))
    model = AutoModel.from_pretrained(get_or_download_model(model_name))
"""
# Root of the local Hugging Face cache: ``$HF_HOME`` when set, otherwise the
# conventional per-user location. ``expanduser()`` is required because ``Path``
# never expands a leading "~" by itself; without it the lookup would target a
# literal "~" directory relative to the current working directory.
HF_DEFAULT_HOME = Path(os.getenv("HF_HOME", "~/.cache/huggingface")).expanduser()
def get_or_download_model(
    repo_id: str,
    revision: str = "main",
) -> Union[str, Path]:
    """
    Resolve a Hub model id to something ``from_pretrained`` can load.

    In a regular environment (``HF_FROM_LIMITED_ENV`` unset, empty, or set to
    a falsy value such as ``false`` or ``0``) the repo id is returned
    unchanged. In a limited environment the path of the locally cached
    snapshot is returned; ``snapshot_download`` is only called when the cache
    holds no reference for the requested revision.

    Args:
        repo_id (str): Model reference containing org_name/model_name such as
            'meta-llama/Llama-2-7b-chat-hf'.
        revision (str): Model revision branch. Defaults to 'main'.

    Returns:
        str | Path: ``repo_id`` itself in a regular environment, otherwise the
        path to the snapshot directory holding the model weights.
    """
    logger = logging.getLogger(__name__)

    # Any of these values means "not a limited environment", so that
    # HF_FROM_LIMITED_ENV=false behaves the same as leaving it unset.
    limited_flag = os.getenv("HF_FROM_LIMITED_ENV", "").strip().lower()
    if limited_flag in ("", "0", "false", "no", "off"):
        logger.info("Regular environment detected.")
        return repo_id

    models_dir = HF_DEFAULT_HOME / "hub"
    cached_name = f'models--{repo_id.replace("/", "--")}'
    cached_dir = models_dir / cached_name

    # refs/<revision> stores the commit hash the revision points at. Checking
    # the ref file (not just the model directory) also covers a cache that
    # exists but was populated for a different revision.
    ref_file = cached_dir / "refs" / revision
    if not ref_file.is_file():
        # snapshot_download returns the snapshot folder it resolved, so there
        # is no need to re-read the ref afterwards.
        return Path(
            snapshot_download(repo_id=repo_id, revision=revision, cache_dir=models_dir)
        )

    # Strip stray whitespace/newlines so the hash is a valid directory name.
    snapshot_hash = ref_file.read_text().strip()
    return cached_dir / "snapshots" / snapshot_hash
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment