Skip to content

Instantly share code, notes, and snippets.

@shellward
Last active December 27, 2022 21:52
Show Gist options
  • Select an option

  • Save shellward/7c7339e0fdf5ed56cbf90e690f309b60 to your computer and use it in GitHub Desktop.

Select an option

Save shellward/7c7339e0fdf5ed56cbf90e690f309b60 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"import torch\n",
"from tqdm.auto import tqdm\n",
"from PIL import Image\n",
"import matplotlib.pyplot as plt\n",
"import skimage.measure as measure\n",
"\n",
"from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config\n",
"from point_e.diffusion.sampler import PointCloudSampler\n",
"from point_e.models.download import load_checkpoint\n",
"from point_e.models.configs import MODEL_CONFIGS, model_from_config\n",
"from point_e.util.pc_to_mesh import marching_cubes_mesh\n",
"from point_e.util.plotting import plot_point_cloud\n",
"from point_e.util.point_cloud import PointCloud"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Text prompts to condition on; each one is turned into its own output file(s).\n",
"prompts = [\n",
"    'desk',\n",
"    'table',\n",
"    'lamp',\n",
"    'tv',\n",
"    'alarm clock',\n",
"    'bed',\n",
"    'pillow',\n",
"]\n",
"\n",
"# Also export a mesh version (point cloud -> SDF -> marching cubes).\n",
"# Meshes have many more vertices than the point clouds.\n",
"meshes = True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Use the GPU when one is available, otherwise fall back to the CPU.\n",
"if torch.cuda.is_available():\n",
"    device = torch.device('cuda')\n",
"else:\n",
"    device = torch.device('cpu')\n",
"\n",
"# --- Base model: conditioned on the text prompt ---\n",
"print('creating base model...')\n",
"base_name = 'base40M-textvec'\n",
"base_model = model_from_config(MODEL_CONFIGS[base_name], device)\n",
"base_model.eval()\n",
"base_diffusion = diffusion_from_config(DIFFUSION_CONFIGS[base_name])\n",
"\n",
"# --- Upsampler model ---\n",
"print('creating upsample model...')\n",
"upsampler_model = model_from_config(MODEL_CONFIGS['upsample'], device)\n",
"upsampler_model.eval()\n",
"upsampler_diffusion = diffusion_from_config(DIFFUSION_CONFIGS['upsample'])\n",
"\n",
"# --- Pretrained weights for both sampling models ---\n",
"print('downloading base checkpoint...')\n",
"base_model.load_state_dict(load_checkpoint(base_name, device))\n",
"\n",
"print('downloading upsampler checkpoint...')\n",
"upsampler_model.load_state_dict(load_checkpoint('upsample', device))\n",
"\n",
"# --- SDF model: passed to marching_cubes_mesh for point cloud -> mesh ---\n",
"print('creating SDF model...')\n",
"name = 'sdf'\n",
"model = model_from_config(MODEL_CONFIGS[name], device)\n",
"model.eval()\n",
"\n",
"print('loading SDF model...')\n",
"model.load_state_dict(load_checkpoint(name, device))\n",
"\n",
"\n",
"def run_point_e(prompt):\n",
"    \"\"\"Generate a colored point cloud for `prompt` and write it under output/.\n",
"\n",
"    Sampling runs the text-conditioned base model (1024 points) followed by the\n",
"    unconditioned upsampler (3072 more points, 4096 in total, with R/G/B channels).\n",
"    When the global `meshes` flag is true, a mesh is also extracted from the\n",
"    point cloud with the SDF model (`model`) via marching cubes.\n",
"\n",
"    Files written (output/ is created if it does not exist):\n",
"      - output/{prompt}_pc.ply: the point cloud (always)\n",
"      - output/{prompt}.ply: the mesh (only when `meshes` is true)\n",
"\n",
"    Relies on globals defined in earlier cells: device, base_model,\n",
"    upsampler_model, base_diffusion, upsampler_diffusion, model and meshes.\n",
"    \"\"\"\n",
"    # open() does not create directories; make output/ up front so a path\n",
"    # problem fails before any sampling work instead of after it.\n",
"    os.makedirs('output', exist_ok=True)\n",
"\n",
"    sampler = PointCloudSampler(\n",
"        device=device,\n",
"        models=[base_model, upsampler_model],\n",
"        diffusions=[base_diffusion, upsampler_diffusion],\n",
"        num_points=[1024, 4096 - 1024],\n",
"        aux_channels=['R', 'G', 'B'],\n",
"        guidance_scale=[3.0, 0.0],\n",
"        model_kwargs_key_filter=('texts', ''),  # Do not condition the upsampler at all\n",
"    )\n",
"\n",
"    # Produce a sample from the model: the loop keeps only the last value\n",
"    # yielded by the progressive sampler, i.e. the final sample.\n",
"    samples = None\n",
"    for x in tqdm(sampler.sample_batch_progressive(batch_size=1, model_kwargs=dict(texts=[prompt]))):\n",
"        samples = x\n",
"\n",
"    pc = sampler.output_to_point_clouds(samples)[0]\n",
"\n",
"    # Save the mesh (optional) and the point cloud.\n",
"    if meshes:\n",
"        mesh = marching_cubes_mesh(\n",
"            pc=pc,\n",
"            model=model,\n",
"            batch_size=4096,\n",
"            grid_size=128,  # 128 matches the resolution used in evals; smaller is faster but coarser\n",
"            progress=True,\n",
"        )\n",
"        with open(f'output/{prompt}.ply', 'wb') as f:\n",
"            mesh.write_ply(f)\n",
"\n",
"    with open(f'output/{prompt}_pc.ply', 'wb') as f:\n",
"        pc.write_ply(f)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Generate and export every prompt in turn, reporting as each one completes.\n",
"for prompt in prompts:\n",
"    run_point_e(prompt)\n",
"    print(f'finished {prompt}')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "6a57af3429fe39f1db95ae52500f0dc20b0d5b033a19c06462446e07082bc71c"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
@enzyme69
Copy link

You probably also need to add:
import skimage.measure as measure

@shellward
Copy link
Author

thanks! :)

@styletransfer
Copy link

Could someone help make a version of this for image-to-point-cloud (image2pointcloud)? I want to convert an image sequence. I tried it myself but I'm just getting errors.

@shellward
Copy link
Author

https://gist.github.com/shellward/96126d05b2ddf2688f677731258086d1 ← this should do the trick, but the results are hit-or-miss 🤷‍♂️

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment