Last active
December 27, 2022 21:52
-
-
Save shellward/7c7339e0fdf5ed56cbf90e690f309b60 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import torch\n", | |
| "from tqdm.auto import tqdm\n", | |
| "from PIL import Image\n", | |
| "import matplotlib.pyplot as plt\n", | |
| "import skimage.measure as measure\n", | |
| "\n", | |
| "from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config\n", | |
| "from point_e.diffusion.sampler import PointCloudSampler\n", | |
| "from point_e.models.download import load_checkpoint\n", | |
| "from point_e.models.configs import MODEL_CONFIGS, model_from_config\n", | |
| "from point_e.util.pc_to_mesh import marching_cubes_mesh\n", | |
| "from point_e.util.plotting import plot_point_cloud\n", | |
| "from point_e.util.point_cloud import PointCloud" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Set the prompts to condition on; one point cloud is generated per prompt.\n", | |
| "prompts = ['desk', 'table', 'lamp', 'tv', 'alarm clock', 'bed', 'pillow']\n", | |
| "\n", | |
| "# Also export a mesh version (point cloud -> SDF -> marching cubes; will have many more vertices!)\n", | |
| "meshes = True" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", | |
| "\n", | |
| "print('creating base model...')\n", | |
| "base_name = 'base40M-textvec'\n", | |
| "base_model = model_from_config(MODEL_CONFIGS[base_name], device)\n", | |
| "base_model.eval()\n", | |
| "base_diffusion = diffusion_from_config(DIFFUSION_CONFIGS[base_name])\n", | |
| "\n", | |
| "print('creating upsample model...')\n", | |
| "upsampler_model = model_from_config(MODEL_CONFIGS['upsample'], device)\n", | |
| "upsampler_model.eval()\n", | |
| "upsampler_diffusion = diffusion_from_config(DIFFUSION_CONFIGS['upsample'])\n", | |
| "\n", | |
| "print('downloading base checkpoint...')\n", | |
| "base_model.load_state_dict(load_checkpoint(base_name, device))\n", | |
| "\n", | |
| "print('downloading upsampler checkpoint...')\n", | |
| "upsampler_model.load_state_dict(load_checkpoint('upsample', device))\n", | |
| "\n", | |
| "print('creating SDF model...')\n", | |
| "name = 'sdf'\n", | |
| "model = model_from_config(MODEL_CONFIGS[name], device)\n", | |
| "model.eval()\n", | |
| "\n", | |
| "print('loading SDF model...')\n", | |
| "model.load_state_dict(load_checkpoint(name, device))\n", | |
| "\n", | |
| "\n", | |
| "def run_point_e(prompt):\n", | |
| " sampler = PointCloudSampler(\n", | |
| " device=device,\n", | |
| " models=[base_model, upsampler_model],\n", | |
| " diffusions=[base_diffusion, upsampler_diffusion],\n", | |
| " num_points=[1024, 4096 - 1024],\n", | |
| " aux_channels=['R', 'G', 'B'],\n", | |
| " guidance_scale=[3.0, 0.0],\n", | |
| " model_kwargs_key_filter=('texts', ''), # Do not condition the upsampler at all\n", | |
| ")\n", | |
| "\n", | |
| "# Produce a sample from the model.\n", | |
| " samples = None\n", | |
| " for x in tqdm(sampler.sample_batch_progressive(batch_size=1, model_kwargs=dict(texts=[prompt]))):\n", | |
| " samples = x\n", | |
| " \n", | |
| " pc = sampler.output_to_point_clouds(samples)[0]\n", | |
| " if meshes:\n", | |
| " mesh = marching_cubes_mesh(\n", | |
| " pc=pc,\n", | |
| " model=model,\n", | |
| " batch_size=4096,\n", | |
| " grid_size=128, # increase to 128 for resolution used in evals\n", | |
| " progress=True,\n", | |
| " )\n", | |
| "#save the mesh and the point cloud\n", | |
| " if meshes:\n", | |
| " with open(f'output/{prompt}.ply', 'wb') as f:\n", | |
| " mesh.write_ply(f)\n", | |
| " \n", | |
| " with open(f'output/{prompt}_pc.ply', 'wb') as f:\n", | |
| " pc.write_ply(f)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "for prompt in prompts:\n", | |
| " run_point_e(prompt)\n", | |
| " print(f'finished {prompt}')" | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.10.8" | |
| }, | |
| "orig_nbformat": 4, | |
| "vscode": { | |
| "interpreter": { | |
| "hash": "6a57af3429fe39f1db95ae52500f0dc20b0d5b033a19c06462446e07082bc71c" | |
| } | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 2 | |
| } |
Author
thanks! :)
Any help making a version of this, but for image2pointcloud? I want to convert an image sequence. I tried it myself, but I'm just getting errors.
Author
https://gist.github.com/shellward/96126d05b2ddf2688f677731258086d1 <-- this should do the trick, but results are 🤷‍♂️
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Probably need to also add:
import skimage.measure as measure