{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aFugjhROU9AH"
      },
      "outputs": [],
      "source": [
        "!git clone https://github.com/openai/point-e.git\n",
        "%cd point-e\n",
        "%pip install -e .\n",
        "from PIL import Image\n",
        "import torch\n",
        "import os\n",
        "import random\n",
        "from tqdm.auto import tqdm\n",
        "\n",
        "from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config\n",
        "from point_e.diffusion.sampler import PointCloudSampler\n",
        "from point_e.models.download import load_checkpoint\n",
        "from point_e.models.configs import MODEL_CONFIGS, model_from_config\n",
        "from point_e.util.plotting import plot_point_cloud\n",
        "from point_e.util.pc_to_mesh import marching_cubes_mesh\n",
        "from point_e.util.point_cloud import PointCloud\n",
        "import matplotlib.pyplot as plt\n",
        "import skimage.measure as measure\n",
        "\n",
        "# Uncomment if you want to load from Google Drive\n",
        "# from google.colab import drive\n",
        "# drive.mount('/content/drive')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vzNiB7-eVUhO"
      },
      "outputs": [],
      "source": [
        "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
        "\n",
        "print('creating base model...')\n",
        "base_name = 'base1B'  # the largest image-conditioned base model; base300M or base40M are smaller and faster\n",
        "base_model = model_from_config(MODEL_CONFIGS[base_name], device)\n",
        "base_model.eval()\n",
        "base_diffusion = diffusion_from_config(DIFFUSION_CONFIGS[base_name])\n",
        "\n",
        "print('creating upsample model...')\n",
        "upsampler_model = model_from_config(MODEL_CONFIGS['upsample'], device)\n",
        "upsampler_model.eval()\n",
        "upsampler_diffusion = diffusion_from_config(DIFFUSION_CONFIGS['upsample'])\n",
        "\n",
        "print('downloading base checkpoint...')\n",
        "base_model.load_state_dict(load_checkpoint(base_name, device))\n",
        "\n",
        "print('downloading upsampler checkpoint...')\n",
        "upsampler_model.load_state_dict(load_checkpoint('upsample', device))\n",
        "\n",
        "print('creating SDF model...')\n",
        "name = 'sdf'\n",
        "model = model_from_config(MODEL_CONFIGS[name], device)\n",
        "model.eval()\n",
        "\n",
        "print('loading SDF model...')\n",
        "model.load_state_dict(load_checkpoint(name, device))"
      ]
    },
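    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "## Optional check (added sketch, not part of the original gist): confirm the runtime\n",
        "## has a GPU and that the checkpoints loaded; sampling on CPU is very slow.\n",
        "print(f'device: {device}')\n",
        "print(f'base model parameters: {sum(p.numel() for p in base_model.parameters()):,}')\n",
        "print(f'upsampler parameters: {sum(p.numel() for p in upsampler_model.parameters()):,}')"
      ]
    },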
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6uC-qHA0Vgqi"
      },
      "outputs": [],
      "source": [
        "def run_pointe(image_path, image_name):\n",
        "    sampler = PointCloudSampler(\n",
        "        device=device,\n",
        "        models=[base_model, upsampler_model],\n",
        "        diffusions=[base_diffusion, upsampler_diffusion],\n",
        "        num_points=[1024, 4096 - 1024],\n",
        "        aux_channels=['R', 'G', 'B'],\n",
        "        guidance_scale=[3.0, 3.0],\n",
        "    )\n",
        "\n",
        "    # Load an image to condition on.\n",
        "    img = Image.open(image_path)\n",
        "\n",
        "    # # Resize the image to 256x256 if needed.\n",
        "    # img = img.resize((256, 256))\n",
        "\n",
        "    # Produce a sample from the model.\n",
        "    samples = None\n",
        "    for x in tqdm(sampler.sample_batch_progressive(batch_size=1, model_kwargs=dict(images=[img]))):\n",
        "        samples = x\n",
        "\n",
        "    pc = sampler.output_to_point_clouds(samples)[0]\n",
        "    fig = plot_point_cloud(pc, grid_size=3, fixed_bounds=((-0.75, -0.75, -0.75), (0.75, 0.75, 0.75)))\n",
        "\n",
        "    mesh = marching_cubes_mesh(\n",
        "        pc=pc,\n",
        "        model=model,\n",
        "        batch_size=4096,\n",
        "        grid_size=128,  # 128 is the resolution used in the point-e evals; lower values are faster\n",
        "        progress=True,\n",
        "    )\n",
        "\n",
        "    if not os.path.exists('output'):\n",
        "        os.makedirs('output')\n",
        "\n",
        "    # image_name keeps its file extension, so outputs look like 'foo.png.ply'.\n",
        "    with open(f'output/{image_name}.ply', 'wb') as f:\n",
        "        mesh.write_ply(f)\n",
        "\n",
        "    with open(f'output/{image_name}_pc.ply', 'wb') as f:\n",
        "        pc.write_ply(f)\n",
        "\n",
        "# Uncomment to run on the images in a folder in random order\n",
        "# def batch_folder_random_order(folder_path):\n",
        "#     images = os.listdir(folder_path)\n",
        "#     random.shuffle(images)\n",
        "#     for image in images:\n",
        "#         if image.endswith(\".jpg\") or image.endswith(\".png\"):\n",
        "#             run_pointe(f\"{folder_path}/{image}\", image)\n",
        "#             print(f\"finished {image}\")\n",
        "#         else:\n",
        "#             continue\n",
        "\n",
        "def batch_folder(folder_path):\n",
        "    images = os.listdir(folder_path)\n",
        "    for image in images:\n",
        "        if image.endswith(\".jpg\") or image.endswith(\".png\"):\n",
        "            run_pointe(f\"{folder_path}/{image}\", image)\n",
        "            print(f\"finished {image}\")\n",
        "        else:\n",
        "            continue\n",
        "\n",
        "## Set the directory of your images here.\n",
        "## To process them in random order, uncomment batch_folder_random_order above and call it instead.\n",
        "batch_folder(\"/content/drive/MyDrive/Colab Notebooks/point-e/images\")  # change to your directory"
      ]
    },
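    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "## Optional sanity check (added sketch, not part of the original gist):\n",
        "## load one exported mesh with trimesh (an extra dependency) and print its size.\n",
        "## The filename below is hypothetical; substitute any .ply that batch_folder wrote to output/.\n",
        "%pip install -q trimesh\n",
        "import trimesh\n",
        "\n",
        "mesh = trimesh.load('output/example.png.ply')  # run_pointe names each mesh '<image file name>.ply'\n",
        "print(f'{len(mesh.vertices)} vertices, {len(mesh.faces)} faces')"
      ]
    },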
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "FMRmzEhFLxbK",
        "outputId": "1ded2f97-4769-4065-e351-88e183f657ab"
      },
      "outputs": [],
      "source": [
        "## Stanford Cars Dataset\n",
        "## This was largely unsuccessful; I wouldn't suggest using it.\n",
        "\n",
        "# def batch_stanford_cars(folder_path):\n",
        "#     # Traverse the folder structure, choosing one image from each class folder to run point-e on.\n",
        "#     for folder in os.listdir(folder_path):\n",
        "#         # Choose one image from each folder.\n",
        "#         image = random.choice(os.listdir(f\"{folder_path}/{folder}\"))\n",
        "#         if image.endswith(\".jpg\") or image.endswith(\".png\"):\n",
        "#             run_pointe(f\"{folder_path}/{folder}/{image}\", image)\n",
        "#             print(f\"finished {image}\")\n",
        "#         else:\n",
        "#             continue\n",
        "\n",
        "# !kaggle datasets download -d jutrera/stanford-car-dataset-by-classes-folder\n",
        "# !unzip stanford-car-dataset-by-classes-folder.zip\n",
        "# batch_stanford_cars('car_data/car_data/train')\n",
        "# !zip -r /content/output.zip /content/point-e/output"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "provenance": []
    },
    "gpuClass": "premium",
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}