Skip to content

Instantly share code, notes, and snippets.

@omnp
Created November 23, 2024 11:45
Show Gist options
  • Select an option

  • Save omnp/f117bccd5737aed8e15019fbfacaece4 to your computer and use it in GitHub Desktop.

Select an option

Save omnp/f117bccd5737aed8e15019fbfacaece4 to your computer and use it in GitHub Desktop.
PyTorch-based implementation of the final scene as seen in Ray Tracing in One Weekend — The Book Series, Ray Tracing the Next Week (as of August 2022)
import torch, torch.distributions as distributions
import numpy
import math
import random
# Module-wide configuration; main() may overwrite these from argv
# (device, DTYPE) and populates the rest once those are final.
DTYPE = torch.float32  # dtype used by scalar()/vector() and most buffers below
device = 'cuda'        # torch device string used for new tensors
# The following stay None until main() initialises them.
pi = inf = None
uniform = normal = half_normal = None
def deg_to_rad(angle):
    """Convert *angle* (number or tensor) from degrees to radians.

    Uses ``math.pi`` instead of the module-level ``pi`` global, which is None
    until main() initialises it, so this helper works at any time.
    """
    return angle * math.pi / 180.
def scalar(x):
    """Wrap a Python number as a 0-d tensor with the module's DTYPE on its device."""
    result = torch.tensor(x, device=device, dtype=DTYPE)
    return result
def vector(xs):
    """Build a tensor from a sequence, using the module's DTYPE and device."""
    result = torch.tensor(xs, device=device, dtype=DTYPE)
    return result
def dot(x, y, dim = -1):
    """Return the (broadcast) dot product of *x* and *y* reduced over *dim*."""
    return (x * y).sum(dim=dim)
def cross(u, v):
    """Return the cross product u x v over the last axis; the result has u's shape/dtype."""
    ux, uy, uz = u[..., 0], u[..., 1], u[..., 2]
    vx, vy, vz = v[..., 0], v[..., 1], v[..., 2]
    out = torch.zeros_like(u)
    out[..., 0] = uy * vz - uz * vy
    out[..., 1] = uz * vx - ux * vz
    out[..., 2] = ux * vy - uy * vx
    return out
def length(x, dim = -1):
    """Euclidean norm of *x* along *dim*."""
    squared = (x * x).sum(dim=dim)
    return squared.sqrt()
def length2(x, dim = -1):
    """Squared Euclidean norm of *x* along *dim* (no square root)."""
    return (x * x).sum(dim=dim)
def unit(x):
    """Normalise 3-vectors along the last axis (zero vectors yield NaN)."""
    norm = length(x)
    return x / torch.stack([norm] * 3, dim=-1)
def random_unit_vector(s, dist=None):
    """Sample random unit 3-vectors.

    Args:
        s: shape-like; only ``s[:2]`` is used, so ``(H, W, 3)`` gives an
            ``(H, W, 3)`` result and ``(N,)`` gives ``(N, 3)``.
        dist: torch distribution supplying the two random parameters per
            vector; defaults to the module-level ``uniform``.

    With a U(0, 1) ``dist`` the result is uniformly distributed on the unit
    sphere: z is uniform on [-1, 1] (Archimedes' hat-box theorem) and the
    azimuth is uniform on [0, 2*pi). Drawing the polar angle uniformly
    instead would over-weight the poles and bias diffuse scattering.
    The clamps keep the output a unit vector even for distributions whose
    samples fall outside [0, 1].
    """
    if dist is None:
        dist = uniform
    r = dist.sample(tuple(s[:2]) + (2,))
    z = torch.clamp(1.0 - 2.0 * r[..., 0], -1.0, 1.0)
    phi = (2.0 * math.pi) * r[..., 1]
    rho = torch.sqrt(torch.clamp(1.0 - z * z, min=0.0))
    return torch.stack((rho * torch.cos(phi), rho * torch.sin(phi), z), dim=-1)
def random_in_unit_sphere(s, dist=None):
    """Sample random points inside the unit ball.

    Args:
        s: shape-like; only ``s[:2]`` is used (result shape ``s[:2] + (3,)``).
        dist: torch distribution supplying the three random parameters per
            point; defaults to the module-level ``uniform``.

    With a U(0, 1) ``dist`` the points are uniform in volume. The direction
    is uniform on the sphere (z uniform on [-1, 1], azimuth on [0, 2*pi)),
    and the radius is the cube root of a uniform sample, because the volume
    of a shell grows as r**2.
    """
    if dist is None:
        dist = uniform
    r = dist.sample(tuple(s[:2]) + (3,))
    z = torch.clamp(1.0 - 2.0 * r[..., 0], -1.0, 1.0)
    phi = (2.0 * math.pi) * r[..., 1]
    radius = torch.clamp(r[..., 2], min=0.0) ** (1.0 / 3.0)
    rho = torch.sqrt(torch.clamp(1.0 - z * z, min=0.0))
    x = radius * rho * torch.cos(phi)
    y = radius * rho * torch.sin(phi)
    return torch.stack((x, y, radius * z), dim=-1)
def random_in_unit_disk(s, dist=None):
    """Sample points on the z = 0 disk, returned in a zero tensor of shape *s*.

    The angle is ``2*pi`` times the first sample. The radius is the second
    sample used directly (no square root), so with a uniform ``dist`` the
    points are denser towards the centre. ``dist`` defaults to the
    module-level ``uniform``.
    """
    sampler = uniform if dist is None else dist
    samples = sampler.sample(s[:2] + (2,))
    out = torch.zeros(s, dtype=DTYPE, device=device)
    angle = samples[..., 0] * (2. * pi)
    radius = samples[..., 1]
    out[..., 0] = radius * torch.cos(angle)
    out[..., 1] = radius * torch.sin(angle)
    return out
def near_zero(x):
    """Bool mask: True where all three xyz components have magnitude below 1e-8."""
    eps = 1e-8
    return (x[..., :3].abs() < eps).all(dim=-1)
# Readability aliases: colours and points are plain vector() tensors.
color = point = vector
class Texture:
    """Interface for textures: map texture coords *uv* and points *p* to colours."""

    def value(self, uv, p) -> color:
        """Return the colour at (uv, p); the base implementation returns None."""
class Material:
    """Base material. Each instance registers itself under a unique integer index.

    Hit records store that index in ``target``, and ``Material.instances`` maps
    it back to the material object.
    """
    index = 0       # running counter shared by every subclass
    instances = {}  # index -> Material instance

    def __init__(self) -> None:
        new_index = Material.index + 1
        Material.index = new_index
        self.index = new_index
        Material.instances[new_index] = self

    def scatter(self, ray, rec, cc_index):
        """Scatter *ray* at *rec*; the base material does not scatter (returns None)."""
        return None

    def emitted(self, uv, p, cc_index):
        """Emitted radiance for colour channel(s) *cc_index*: black by default."""
        black = color((0, 0, 0))
        return black[cc_index]
class SolidColor(Texture):
    """Texture returning one constant colour tensor regardless of uv/p."""

    def __init__(self, c: color) -> None:
        self.color = c  # returned as-is (not broadcast to the batch shape)

    def value(self, uv, p) -> color:
        """Return the stored colour."""
        return self.color
class CheckerTexture(Texture):
    """3-D checker pattern alternating between two textures.

    The sign of sin(10x)*sin(10y)*sin(10z) at point p chooses ``odd``
    (negative) or ``even`` (non-negative).
    """

    def __init__(self, even: Texture, odd: Texture) -> None:
        self.even = even
        self.odd = odd

    def value(self, uv, p) -> color:
        """Return the per-point checker colour; p is expected to be (..., 3)."""
        sines = torch.sin(10 * p[..., 0]) * torch.sin(10 * p[..., 1]) * torch.sin(10 * p[..., 2])
        c_odd = self.odd.value(uv, p)
        c_even = self.even.value(uv, p)
        # A trailing singleton axis broadcasts the mask over the colour channels.
        return torch.where((sines < 0).unsqueeze(-1), c_odd, c_even)
class Perlin:
    """Vectorised Perlin gradient noise with 256 random gradients and xyz permutations."""
    point_count = 256

    def __init__(self) -> None:
        self.ranvec = random_unit_vector((Perlin.point_count,), dist=uniform)
        # Generated in x, y, z order (each call shuffles with `random`).
        self.perm_x, self.perm_y, self.perm_z = (perlin_generate_perm() for _ in range(3))

    def noise(self, p):
        """Noise value per point; p is indexed as (H, W, 3) by the c[:, :, ...] writes below."""
        base = p.floor()
        frac = p - base
        cell = base.long()
        offsets = torch.cartesian_prod(*[torch.arange(2, dtype=torch.long, device=device)] * 3)
        grads = torch.zeros(p.size()[:2] + offsets.size(), dtype=DTYPE, device=device)
        for corner, (di, dj, dk) in enumerate(offsets):
            # Hash the lattice corner into the gradient table (indices wrapped with & 255).
            hashed = (self.perm_x[(cell[..., 0] + di) & 255]
                      ^ self.perm_y[(cell[..., 1] + dj) & 255]
                      ^ self.perm_z[(cell[..., 2] + dk) & 255])
            grads[:, :, corner, :] = self.ranvec[hashed]
        return perlin_interp(grads, frac)

    def turb(self, p, depth=7):
        """Absolute sum of *depth* octaves, each at double frequency and half weight."""
        total = torch.zeros(p.size()[:2], dtype=DTYPE, device=device)
        sample = p.clone()
        w = 1.0
        for _ in range(depth):
            total += w * self.noise(sample)
            w *= 0.5
            sample *= 2
        return total.abs()
def perlin_generate_perm():
    """Return a random permutation of range(Perlin.point_count) as a long tensor."""
    perm = list(range(Perlin.point_count))
    random.shuffle(perm)
    return torch.tensor(perm, dtype=torch.long, device=device)
def trilinear_interp(c, uvw):
    """Trilinearly interpolate scalar corner values.

    Args:
        c: (..., 8) corner values, in ``torch.cartesian_prod`` order of
            (i, j, k) in {0, 1}^3 (the same corner layout perlin_interp uses).
        uvw: (..., 3) fractional position inside the cell, in [0, 1].

    Returns:
        (...) interpolated values, with c's dtype and device.

    Each corner is weighted by the product of its three per-axis weights:
    ``u`` when the corner index is 1 on that axis, ``1 - u`` when it is 0.
    """
    corners = torch.cartesian_prod(*3 * (torch.arange(2, device=uvw.device),))
    accum = torch.zeros(c.size()[:-1], dtype=c.dtype, device=c.device)
    for i, idx in enumerate(corners):
        w = idx * uvw + (1 - idx) * (1 - uvw)  # (..., 3) per-axis weights
        accum += w[..., 0] * w[..., 1] * w[..., 2] * c[..., i]
    return accum
def perlin_interp(c, uvw):
    """Perlin gradient interpolation with Hermite smoothing.

    Args:
        c: (..., 8, 3) gradient vectors per cell corner, in
            ``torch.cartesian_prod`` order of (i, j, k) in {0, 1}^3.
        uvw: (..., 3) fractional position inside the cell.

    Returns:
        (...) noise values with c's dtype and device. Works for any number
        of leading batch dims (the renderer passes (H, W, 8, 3)).
    """
    smooth = uvw * uvw * (3 - 2 * uvw)  # Hermite fade 3t^2 - 2t^3
    corners = torch.cartesian_prod(*3 * (torch.arange(2, dtype=torch.long, device=uvw.device),))
    accum = torch.zeros(c.size()[:-2], dtype=c.dtype, device=c.device)
    for i, idx in enumerate(corners):
        weight_v = uvw - idx  # offset from this corner to the sample point
        w = idx * smooth + (1 - idx) * (1 - smooth)  # per-axis blend weights
        accum += w[..., 0] * w[..., 1] * w[..., 2] * (c[..., i, :] * weight_v).sum(-1)
    return accum
class NoiseTexture(Texture):
    """Marble-like texture: sin(scale_z * p.z + 10 * turbulence), mapped into [0, 1] grey.

    Only channel 2 of ``scale``'s colour is used, as the frequency along p.z.
    """

    def __init__(self, scale: Texture) -> None:
        self.noise = Perlin()
        self.scale = scale

    def value(self, uv, p) -> color:
        """Return the grey marble colour for points *p*."""
        turbulence = self.noise.turb(p)
        frequency = self.scale.value(uv, p)[..., 2]
        wave = torch.sin(frequency * p[..., 2] + 10 * turbulence)
        return color((1, 1, 1)) * 0.5 * (1 + wave.unsqueeze(-1))
class ImageTexture(Texture):
    """Nearest-neighbour texture lookup into an image tensor indexed image[i, j, channel].

    Pixel values are scaled by 1/255 on lookup.
    """

    def __init__(self, image) -> None:
        self.image = image

    def value(self, uv, p) -> color:
        """Return the colour at texture coords *uv* (..., 2); v is flipped (1 - v)."""
        uv = clamp(uv)  # new tensor, so the caller's uv is not modified below
        uv[..., 1] = 1.0 - uv[..., 1]
        size_i, size_j = int(self.image.size(0)), int(self.image.size(1))
        wh = torch.tensor((size_i, size_j), dtype=torch.long, device=self.image.device)
        ij = (uv * wh).long()
        # Python-int bounds: no per-call 0-d tensor fill values, which would
        # force a device sync on CUDA.
        i = ij[..., 0].clamp(0, size_i - 1)
        j = ij[..., 1].clamp(0, size_j - 1)
        color_scale = 1.0 / 255.0
        return color_scale * self.image[i, j, :]
def clamp(xs, low=0.0, high=1.0):
    """Clamp *xs* elementwise into [low, high]; the lower bound wins if low > high."""
    capped = torch.where(xs > high, torch.full_like(xs, high), xs)
    return torch.where(xs < low, torch.full_like(xs, low), capped)
class Record:
    """Batched hit record, one entry per ray.

    ray.origin is indexed [:, :, ...], so a ray batch with two leading dims
    (e.g. (H, W, 3)) is assumed.
    """

    def __init__(self, ray):
        o = ray.origin
        batch = o.size()[:2]
        self.p = torch.zeros_like(o).to(device)
        self.normal = unit(torch.ones_like(o).to(device))
        self.t = torch.ones_like(o[:, :, 0]).to(device)
        self.uv = torch.zeros_like(o[:, :, :2]).to(device)
        self.front_face = torch.full(batch, False).to(device)
        self.target = torch.zeros(batch, dtype=torch.int).to(device)  # material index

    def set_face_normal(self, ray, outward_normal):
        """Store the normal flipped to oppose the ray, and whether the ray hit the outside."""
        facing = dot(ray.direction, outward_normal) < 0.
        self.front_face = facing
        self.normal = torch.where(torch.stack(3 * (facing,), dim=-1), outward_normal, -outward_normal)

    def update(self, mask, rec):
        """Copy every field of *rec* into self wherever *mask* is True."""
        m3 = torch.stack(3 * (mask,), dim=-1)
        m2 = torch.stack(2 * (mask,), -1)
        self.p = torch.where(m3, rec.p, self.p)
        self.normal = torch.where(m3, rec.normal, self.normal)
        self.t = torch.where(mask, rec.t, self.t)
        self.uv = torch.where(m2, rec.uv, self.uv)
        self.front_face = torch.where(mask, rec.front_face, self.front_face)
        self.target = torch.where(mask, rec.target, self.target)

    def copy_to(self, r):
        """Overwrite *r*'s fields with clones of this record's fields."""
        for name in ('p', 'normal', 't', 'uv', 'front_face', 'target'):
            setattr(r, name, getattr(self, name).clone())
class aabb:
    """Axis-aligned bounding box; corners are rounded outward to whole numbers."""

    def __init__(self, minimum, maximum) -> None:
        self.minimum = torch.floor(minimum)
        self.maximum = torch.ceil(maximum)

    def hit(self, ray, t_min, t_max):
        """Slab test: bool mask of rays whose parameter range [t_min, t_max] overlaps the box.

        Args:
            ray: object with ``origin`` and ``direction`` tensors of shape (..., 3).
            t_min, t_max: Python numbers or tensors broadcastable to ``origin[..., 0]``.

        Returns:
            Bool tensor shaped like ``ray.origin[..., 0]``. The test is
            conservative: a ray lying exactly on a slab plane and parallel to it
            (0 * inf = NaN) is treated as inside that slab.
        """
        inv_d = 1.0 / ray.direction  # +-inf for axis-parallel rays
        t0 = (self.minimum - ray.origin) * inv_d
        t1 = (self.maximum - ray.origin) * inv_d
        t_near = torch.minimum(t0, t1)
        t_far = torch.maximum(t0, t1)
        # torch.minimum/maximum propagate NaN; widen such slabs to (-inf, inf).
        t_near = torch.where(torch.isnan(t_near), torch.full_like(t_near, -math.inf), t_near)
        t_far = torch.where(torch.isnan(t_far), torch.full_like(t_far, math.inf), t_far)
        lo = t_near.amax(dim=-1)
        hi = t_far.amin(dim=-1)
        lo = torch.maximum(lo, torch.as_tensor(t_min, dtype=lo.dtype, device=lo.device))
        hi = torch.minimum(hi, torch.as_tensor(t_max, dtype=hi.dtype, device=hi.device))
        return lo <= hi
def surrounding_box(box0: aabb, box1: aabb) -> aabb:
    """Return the smallest aabb that encloses both *box0* and *box1*."""
    lo = torch.minimum(box0.minimum, box1.minimum)
    hi = torch.maximum(box0.maximum, box1.maximum)
    return aabb(lo, hi)
class Object3:
    """Base interface for anything a batch of rays can hit."""

    def hit(self, ray, t_min, t_max, rec):
        """Intersect *ray*, merge hits into *rec*, return a bool mask; base returns None."""

    def bounding_box(self, time0, time1):
        """Return an aabb enclosing the object over [time0, time1]; base returns None."""
class Sphere(Object3):
    """Sphere whose centre moves linearly from center0 (at time0) to center1 (at time1).

    center0/center1 default to the origin. The defaults are built at
    construction time rather than at class-definition time, so they respect
    the DTYPE/device chosen in main(); a 'cuda' tensor created at import
    would crash CPU-only runs.
    """

    def __init__(self, time0, time1, center0 = None, center1 = None, radius = 1., material=Material(), object_index=0):
        self.center0 = point([0., 0., 0.]) if center0 is None else center0
        self.center1 = point([0., 0., 0.]) if center1 is None else center1
        self.radius = radius
        self.material = material
        self.object = object_index
        self.time0 = time0
        self.time1 = time1

    def hit(self, ray, t_min, t_max, rec):
        """Ray/sphere intersection; merges hits into *rec* only where the returned mask is True."""
        center = self.center(ray.time)
        oc = ray.origin - center
        a = dot(ray.direction, ray.direction)
        half_b = dot(oc, ray.direction)
        c = dot(oc, oc) - self.radius * self.radius
        discriminant = half_b * half_b - a * c
        sqrtd = torch.sqrt(torch.max(torch.zeros_like(discriminant), discriminant))
        root1 = (-half_b - sqrtd) / a
        root2 = (-half_b + sqrtd) / a
        # Prefer the nearer root; fall back to the farther one if it is out of range.
        root = torch.where((root1 < t_min).logical_or(t_max < root1), root2, root1)
        mask = (root >= t_min).logical_and(t_max >= root)
        mask = (discriminant >= .0).logical_and(mask)
        mask3 = torch.stack(3 * (mask,), dim=-1)
        rec.t = torch.where(mask, root, rec.t)
        rec.p = torch.where(mask3, ray.at(torch.stack(3 * (rec.t,), dim=-1)), rec.p)
        outward_normal = (rec.p - center) / self.radius
        # Orient the normal against the ray, but commit normal/front_face only
        # for rays that actually hit, so other rays keep their previous values.
        front_face = dot(ray.direction, outward_normal) < 0.
        normal = torch.where(torch.stack(3 * (front_face,), dim=-1), outward_normal, -outward_normal)
        rec.normal = torch.where(mask3, normal, rec.normal)
        rec.front_face = torch.where(mask, front_face, rec.front_face)
        uv = get_sphere_uv(outward_normal)
        rec.uv = torch.where(torch.stack(2 * (mask,), -1), uv, rec.uv)
        rec.target = torch.where(mask, self.material.index, rec.target)
        return mask

    def bounding_box(self, time0, time1):
        """Box enclosing the sphere at both time0 and time1."""
        r = vector(3 * (self.radius,))
        box0 = aabb(self.center(time0) - r, self.center(time0) + r)
        box1 = aabb(self.center(time1) - r, self.center(time1) + r)
        return surrounding_box(box0, box1)

    def center(self, time):
        """Centre at *time* (number or tensor; a tensor gets a trailing xyz axis).

        A stationary sphere (time0 == time1) returns center0 instead of
        dividing by zero.
        """
        span = self.time1 - self.time0
        delta = self.center1 - self.center0
        if isinstance(time, torch.Tensor):
            frac = (time - self.time0) / span if span != 0 else torch.zeros_like(time)
            return frac.unsqueeze(-1) * delta + self.center0
        frac = (time - self.time0) / span if span != 0 else 0.0
        return frac * delta + self.center0
def get_sphere_uv(p):
    """Map unit-sphere points *p* (..., 3) to texture coordinates (..., 2) in [0, 1].

    u = (atan2(-z, x) + pi) / (2*pi) and v = acos(-y) / pi. For example,
    (1, 0, 0) -> (0.5, 0.5), (0, 1, 0) -> (0.5, 1.0) and (0, 0, 1) -> (0.25, 0.5).
    Uses math.pi, so it does not depend on the module-level ``pi`` (None
    until main() runs).
    """
    theta = torch.acos(-p[..., 1])
    phi = torch.atan2(-p[..., 2], p[..., 0]) + math.pi
    return torch.stack((phi / (2 * math.pi), theta / math.pi), -1)
class XY_Rect(Object3):
    """Rectangle in the plane z = k spanning [x0, x1] x [y0, y1]; outward normal +z."""

    def __init__(self, x0, x1, y0, y1, k, material, object_index=0) -> None:
        self.x0, self.x1 = x0, x1
        self.y0, self.y1 = y0, y1
        self.k = k
        self.material = material
        self.object = object_index

    def hit(self, ray, t_min, t_max, rec):
        """Intersect the plane; merge in-rectangle, in-range hits into *rec*; return the mask."""
        o, d = ray.origin, ray.direction
        t = (self.k - o[..., 2]) / d[..., 2]
        x = o[..., 0] + t * d[..., 0]
        y = o[..., 1] + t * d[..., 1]
        outside = (x < self.x0) | (x > self.x1) | (y < self.y0) | (y > self.y1)
        mask = (t >= t_min).logical_and(t_max >= t) & ~outside
        candidate = Record(ray)
        candidate.uv[..., 0] = (x - self.x0) / (self.x1 - self.x0)
        candidate.uv[..., 1] = (y - self.y0) / (self.y1 - self.y0)
        candidate.t = t
        candidate.set_face_normal(ray, vector((0, 0, 1)))
        candidate.target = self.material.index
        candidate.p = ray.at(torch.stack(3 * (t,), -1))
        rec.update(mask, candidate)
        return mask

    def bounding_box(self, time0, time1):
        """Thin box padded by 1e-4 along z."""
        lo = point((self.x0, self.y0, self.k - 0.0001))
        hi = point((self.x1, self.y1, self.k + 0.0001))
        return aabb(lo, hi)
class XZ_Rect(Object3):
    """Rectangle in the plane y = k spanning [x0, x1] x [z0, z1]; outward normal +y."""

    def __init__(self, x0, x1, z0, z1, k, material, object_index=0) -> None:
        self.x0, self.x1 = x0, x1
        self.z0, self.z1 = z0, z1
        self.k = k
        self.material = material
        self.object = object_index

    def hit(self, ray, t_min, t_max, rec):
        """Intersect the plane; merge in-rectangle, in-range hits into *rec*; return the mask."""
        o, d = ray.origin, ray.direction
        t = (self.k - o[..., 1]) / d[..., 1]
        x = o[..., 0] + t * d[..., 0]
        z = o[..., 2] + t * d[..., 2]
        outside = (x < self.x0) | (x > self.x1) | (z < self.z0) | (z > self.z1)
        mask = (t >= t_min).logical_and(t_max >= t) & ~outside
        candidate = Record(ray)
        candidate.uv[..., 0] = (x - self.x0) / (self.x1 - self.x0)
        candidate.uv[..., 1] = (z - self.z0) / (self.z1 - self.z0)
        candidate.t = t
        candidate.set_face_normal(ray, vector((0, 1, 0)))
        candidate.target = self.material.index
        candidate.p = ray.at(torch.stack(3 * (t,), -1))
        rec.update(mask, candidate)
        return mask

    def bounding_box(self, time0, time1):
        """Thin box padded by 1e-4 along y."""
        lo = point((self.x0, self.k - 0.0001, self.z0))
        hi = point((self.x1, self.k + 0.0001, self.z1))
        return aabb(lo, hi)
class YZ_Rect(Object3):
    """Rectangle in the plane x = k spanning [y0, y1] x [z0, z1]; outward normal +x."""

    def __init__(self, y0, y1, z0, z1, k, material, object_index=0) -> None:
        self.y0, self.y1 = y0, y1
        self.z0, self.z1 = z0, z1
        self.k = k
        self.material = material
        self.object = object_index

    def hit(self, ray, t_min, t_max, rec):
        """Intersect the plane; merge in-rectangle, in-range hits into *rec*; return the mask."""
        o, d = ray.origin, ray.direction
        t = (self.k - o[..., 0]) / d[..., 0]
        y = o[..., 1] + t * d[..., 1]
        z = o[..., 2] + t * d[..., 2]
        outside = (y < self.y0) | (y > self.y1) | (z < self.z0) | (z > self.z1)
        mask = (t >= t_min).logical_and(t_max >= t) & ~outside
        candidate = Record(ray)
        candidate.uv[..., 0] = (y - self.y0) / (self.y1 - self.y0)
        candidate.uv[..., 1] = (z - self.z0) / (self.z1 - self.z0)
        candidate.t = t
        candidate.set_face_normal(ray, vector((1, 0, 0)))
        candidate.target = self.material.index
        candidate.p = ray.at(torch.stack(3 * (t,), -1))
        rec.update(mask, candidate)
        return mask

    def bounding_box(self, time0, time1):
        """Thin box padded by 1e-4 along x."""
        lo = point((self.k - 0.0001, self.y0, self.z0))
        hi = point((self.k + 0.0001, self.y1, self.z1))
        return aabb(lo, hi)
class Box(Object3):
    """Axis-aligned box from corner p0 to p1, made of six rectangles with one material."""

    def __init__(self, p0, p1, material) -> None:
        self.box_min = p0
        self.box_max = p1
        self.material = material
        self.sides = Object3s()
        faces = (
            XY_Rect(p0[0], p1[0], p0[1], p1[1], p1[2], material),
            XY_Rect(p0[0], p1[0], p0[1], p1[1], p0[2], material),
            XZ_Rect(p0[0], p1[0], p0[2], p1[2], p1[1], material),
            XZ_Rect(p0[0], p1[0], p0[2], p1[2], p0[1], material),
            YZ_Rect(p0[1], p1[1], p0[2], p1[2], p1[0], material),
            YZ_Rect(p0[1], p1[1], p0[2], p1[2], p0[0], material),
        )
        for face in faces:
            self.sides.add(face)

    def hit(self, ray, t_min, t_max, rec):
        """Delegate to the six sides (closest hit wins)."""
        return self.sides.hit(ray, t_min, t_max, rec)

    def bounding_box(self, time0, time1):
        """Box spanning the two construction corners."""
        return aabb(self.box_min, self.box_max)
class Translate(Object3):
    """Instance of hittable *p* displaced by *offset*."""

    def __init__(self, p, offset) -> None:
        super().__init__()
        self.p = p
        self.offset = offset

    def hit(self, ray, t_min, t_max, rec):
        """Hit-test in the object's frame, then move hit points back by +offset.

        Translation does not change the ray direction. The child's oriented
        normal and front_face are therefore already correct and are kept as-is.
        Calling set_face_normal again on the already-flipped normal would force
        front_face to True for every hit.
        """
        moved = Ray(ray.origin - self.offset, ray.direction, ray.time)
        h = self.p.hit(moved, t_min, t_max, rec)
        t_rec = Record(moved)
        rec.copy_to(t_rec)
        t_rec.p += self.offset
        rec.update(h, t_rec)
        return h

    def bounding_box(self, time0, time1):
        """Child's box shifted by offset."""
        b = self.p.bounding_box(time0, time1)
        return aabb(b.minimum + self.offset, b.maximum + self.offset)
class Rotate_Y(Object3):
    """Instance of hittable *p* rotated by *angle* degrees about the y axis."""

    def __init__(self, p, angle) -> None:
        super().__init__()
        self.p = p
        self.radians = angle / 360.0 * 2 * pi
        self.sin_theta = math.sin(self.radians)
        self.cos_theta = math.cos(self.radians)
        self.box = p.bounding_box(0, 1)
        # Rotate all 8 corners of the child's box and take their extent.
        mn = point((inf, inf, inf))
        mx = point((-inf, -inf, -inf))
        ijk = torch.arange(2)
        ijk = torch.cartesian_prod(*3 * (ijk,))
        for _, idx in enumerate(ijk[..., :]):
            x = idx[0] * self.box.maximum[0] + (1 - idx[0]) * self.box.minimum[0]
            y = idx[1] * self.box.maximum[1] + (1 - idx[1]) * self.box.minimum[1]
            z = idx[2] * self.box.maximum[2] + (1 - idx[2]) * self.box.minimum[2]
            newx = self.cos_theta * x + self.sin_theta * z
            newz = -self.sin_theta * x + self.cos_theta * z
            tester = vector((newx, y, newz))
            mn = torch.minimum(mn, tester)
            mx = torch.maximum(mx, tester)
        self.box = aabb(mn, mx)

    def hit(self, ray, t_min, t_max, rec):
        """Rotate the ray into object space, hit-test, then rotate p and normal back.

        The child orients its normal against the object-space ray. A rotation
        preserves dot(direction, normal), so the child's front_face and normal
        orientation remain valid after rotating back. The former set_face_normal
        call mixed the object-space ray with the world-space normal, which could
        flip both wrongly.
        """
        origin = ray.origin.clone()
        direction = ray.direction.clone()
        origin[..., 0] = self.cos_theta * ray.origin[..., 0] - self.sin_theta * ray.origin[..., 2]
        origin[..., 2] = self.sin_theta * ray.origin[..., 0] + self.cos_theta * ray.origin[..., 2]
        direction[..., 0] = self.cos_theta * ray.direction[..., 0] - self.sin_theta * ray.direction[..., 2]
        direction[..., 2] = self.sin_theta * ray.direction[..., 0] + self.cos_theta * ray.direction[..., 2]
        rotated = Ray(origin, direction, ray.time)
        h = self.p.hit(rotated, t_min, t_max, rec)
        t_rec = Record(rotated)
        rec.copy_to(t_rec)
        t_rec.p[..., 0] = self.cos_theta * rec.p[..., 0] + self.sin_theta * rec.p[..., 2]
        t_rec.p[..., 2] = -self.sin_theta * rec.p[..., 0] + self.cos_theta * rec.p[..., 2]
        t_rec.normal[..., 0] = self.cos_theta * rec.normal[..., 0] + self.sin_theta * rec.normal[..., 2]
        t_rec.normal[..., 2] = -self.sin_theta * rec.normal[..., 0] + self.cos_theta * rec.normal[..., 2]
        rec.update(h, t_rec)
        return h

    def bounding_box(self, time0, time1):
        """Precomputed box of the rotated child."""
        return self.box
class ConstantMedium(Object3):
    """Constant-density participating medium (fog/smoke) enclosed by another hittable.

    For each ray, the entry and exit crossings of ``boundary`` are found. A
    free-path length is sampled as ``-log(U) / d``. Rays whose sampled
    distance ends inside the boundary register a hit whose material is an
    Isotropic phase function built from texture ``a``.
    """
    def __init__(self, b, d, a: Texture) -> None:
        super().__init__()
        self.boundary = b  # hittable enclosing the medium
        self.neg_inv_density = -1/d  # d is the density; -1/d turns log(U) into a path length
        self.phase_funtion = Isotropic(a)  # NOTE: 'funtion' spelling is the attribute's actual name
    def hit(self, ray, t_min, t_max, rec):
        """Sample a scatter point inside the medium; merge hits into *rec* and return the mask."""
        rec1 = Record(ray)
        rec2 = Record(ray)
        # Entry: boundary crossing searched over the whole line (t in [-inf, inf]).
        h = self.boundary.hit(ray, -inf, inf, rec1)
        # Exit: next boundary crossing just beyond the entry.
        g = self.boundary.hit(ray, rec1.t+0.0001, inf, rec2)
        # Clip the [entry, exit] interval to the caller's [t_min, t_max].
        rec1.t = torch.where(rec1.t < t_min, t_min, rec1.t)
        rec2.t = torch.where(rec2.t > t_max, t_max, rec2.t)
        # Non-empty interval; note this is tested before the entry is clamped to 0.
        i = rec1.t < rec2.t
        rec1.t = torch.where(rec1.t < 0, 0, rec1.t)
        ray_length = length(ray.direction)
        # Distance the ray travels through the medium (t is scaled by |direction|).
        distance_inside_boundary = (rec2.t - rec1.t) * ray_length
        # Exponentially distributed free path: -log(U) / density.
        hit_distance = self.neg_inv_density * torch.log(uniform.sample(distance_inside_boundary.size()))
        # Scatter only if the sampled path ends before the ray leaves the medium.
        j = hit_distance <= distance_inside_boundary
        t_rec = Record(ray)
        rec.copy_to(t_rec)
        t_rec.t = rec1.t + hit_distance / ray_length
        t_rec.p = ray.at(torch.stack(3*(t_rec.t,),-1))
        t_rec.normal = vector((1,0,0))  # arbitrary: Isotropic.scatter does not read the normal
        t_rec.front_face = True  # arbitrary as well
        t_rec.target = self.phase_funtion.index
        mask = ((h & g) & i) & j
        rec.update(mask, t_rec)
        return mask
    def bounding_box(self, time0, time1):
        """Same box as the boundary."""
        return self.boundary.bounding_box(time0, time1)
class Ray:
    """Batch of rays origin + t * direction, each carrying a sample time."""

    def __init__(self, origin, direction, time):
        self.origin, self.direction, self.time = origin, direction, time

    def at(self, t):
        """Point(s) along the ray at parameter *t* (broadcast against origin/direction)."""
        return self.origin + t * self.direction
class Object3s(Object3):
    """Flat list of hittables, tested in insertion order; the closest hit wins."""

    def __init__(self):
        self.objects = []
        self.object_index = 0  # last id handed out by add()

    def add(self, object3):
        """Append *object3* and tag it with the next sequential id (starting at 1)."""
        self.object_index += 1
        object3.object = self.object_index
        self.objects.append(object3)

    def items(self):
        """Number of contained objects."""
        return len(self.objects)

    def get(self, index):
        """First object whose id equals *index*, or None."""
        return next((obj for obj in self.objects if obj.object == index), None)

    def get_at(self, index):
        """Object at list position *index*."""
        return self.objects[index]

    def hit(self, ray, t_min, t_max, rec):
        """Test every object, shrinking t_max per ray; merge hits into *rec*; return the any-hit mask."""
        scratch = Record(ray)
        any_hit = torch.full_like(ray.origin[:, :, 0], False, dtype=torch.bool)
        nearest = t_max
        for obj in self.objects:
            got = obj.hit(ray, t_min, nearest, scratch)
            any_hit = any_hit.logical_or(got)
            nearest = torch.where(got, scratch.t, nearest)
            rec.update(got, scratch)
        return any_hit

    def bounding_box(self, time0, time1):
        """Union of all objects' boxes (None when empty)."""
        result = None
        for obj in self.objects:
            b = obj.bounding_box(time0, time1)
            result = b if result is None else surrounding_box(result, b)
        return result
def box_compare(a, b, axis):
    """True when a's box minimum along *axis* (at time 0) is below b's."""
    lhs = a.bounding_box(0, 0).minimum[axis]
    rhs = b.bounding_box(0, 0).minimum[axis]
    return lhs < rhs
class bvh_node(Object3):
    """Bounding-volume-hierarchy node over ``objects[start:end]`` of an Object3s.

    Children are split at the median along a randomly chosen axis, with
    objects sorted by the minimum corner of their time-0 bounding box.

    Raises:
        ValueError: if *objects* is None or the range is empty. An empty
            range would otherwise recurse forever in the split branch.
    """

    def __init__(self, objects: Object3s = None, time0 = 0, time1 = 0, start=None, end=None) -> None:
        if objects is None:
            raise ValueError('bvh_node requires an Object3s collection')
        self.objects = objects
        self.time0 = time0
        self.time1 = time1
        self.start = 0 if start is None else start
        self.end = self.objects.items() if end is None else end
        self.box = None
        self.left: Object3 = None
        self.right: Object3 = None
        axis = torch.randint(low=0, high=3, size=(1,))
        span = self.end - self.start
        if span <= 0:
            raise ValueError('bvh_node requires at least one object')
        if span == 1:
            # A leaf holds the same object on both sides.
            self.left = self.right = self.objects.get_at(self.start)
        elif span == 2:
            self.left = self.objects.get_at(self.start)
            self.right = self.objects.get_at(self.start + 1)
            if not box_compare(self.left, self.right, axis):
                self.left, self.right = self.right, self.left
        else:
            objects_ = list(self.objects.objects)
            objects_[self.start:self.end] = sorted(objects_[self.start:self.end], key=lambda x: x.bounding_box(0, 0).minimum[axis])
            mid = self.start + span // 2
            # Both halves share the sorted list and address it by [start, end).
            obj1 = Object3s()
            obj1.objects = objects_
            obj1.object_index = self.objects.object_index
            obj2 = Object3s()
            obj2.objects = objects_
            obj2.object_index = self.objects.object_index
            self.left = bvh_node(obj1, time0, time1, self.start, mid)
            self.right = bvh_node(obj2, time0, time1, mid, self.end)
        box_left = self.left.bounding_box(time0, time1)
        box_right = self.right.bounding_box(time0, time1)
        self.box = surrounding_box(box_left, box_right)

    def hit(self, ray, t_min, t_max, rec):
        """Skip the subtree when no ray touches the node box; otherwise test both children."""
        h = self.box.hit(ray, t_min, t_max)
        if not torch.any(h):
            return h
        temp_rec = Record(ray)
        hit_anything = torch.full_like(ray.origin[:, :, 0], False, dtype=torch.bool)
        closest_so_far = t_max
        for object3 in (self.left, self.right):
            h = object3.hit(ray, t_min, closest_so_far, temp_rec)
            hit_anything = hit_anything.logical_or(h)
            closest_so_far = torch.where(h, temp_rec.t, closest_so_far)
            rec.update(h, temp_rec)
        return hit_anything

    def bounding_box(self, time0, time1):
        """Box computed at construction."""
        return self.box

    def items(self):
        return self.objects.items()

    def get(self, i):
        return self.objects.get(i)

    def get_at(self, index):
        return self.objects.get_at(index)
class Lambertian(Material):
    """Diffuse material: scatters around the normal, attenuated by the albedo texture."""

    def __init__(self, albedo: Texture):
        super().__init__()
        self.albedo = albedo

    def scatter(self, ray, rec, cc_index):
        """Return (True, scattered ray, attenuation for channel(s) cc_index)."""
        direction = rec.normal + random_unit_vector(rec.normal.size())
        # A degenerate (near-zero) direction falls back to the normal.
        degenerate = near_zero(direction)
        direction = torch.where(torch.stack(3 * (degenerate,), dim=-1), rec.normal, direction)
        scattered = Ray(rec.p, direction, ray.time)
        return True, scattered, self.albedo.value(rec.uv, rec.p)[..., cc_index]
def reflect(v, n):
    """Mirror-reflect *v* about normal *n* (both (..., 3)): v - 2 (v.n) n."""
    proj = 2.0 * dot(v, n)
    return v - torch.stack((proj, proj, proj), dim=-1) * n
class Metal(Material):
    """Reflective material, with the mirror direction perturbed by a fuzz texture."""

    def __init__(self, albedo: Texture, fuzz: Texture):
        super().__init__()
        self.albedo = albedo
        self.fuzz = fuzz  # texture; not clamped to <= 1

    def scatter(self, ray, rec, cc_index):
        """Return (mask of rays leaving above the surface, scattered ray, attenuation)."""
        mirror = reflect(unit(ray.direction), rec.normal)
        jitter = self.fuzz.value(rec.uv, rec.p)[..., cc_index] * random_in_unit_sphere(rec.normal.size())
        fuzzed = mirror + jitter
        # If the perturbation cancels the reflection, keep the plain mirror ray.
        keep_mirror = near_zero(fuzzed).unsqueeze(-1)
        scattered = Ray(rec.p, torch.where(keep_mirror, mirror, fuzzed), ray.time)
        attenuation = self.albedo.value(rec.uv, rec.p)[..., cc_index]
        return dot(scattered.direction, rec.normal) > 0, scattered, attenuation
def refract(uv, n, etai_over_etat):
    """Snell refraction of unit direction *uv* through normal *n* with ratio eta_i/eta_t.

    etai_over_etat is expected to have uv's batch shape (it is stacked onto a
    trailing xyz axis).
    """
    one = torch.tensor([1.0], dtype=DTYPE, device=device)
    cos_theta = torch.minimum(dot(-uv, n), one)
    ratio3 = torch.stack(3 * (etai_over_etat,), -1)
    cos3 = torch.stack(3 * (cos_theta,), -1)
    perp = ratio3 * (uv + cos3 * n)
    parallel_scale = -torch.sqrt(torch.abs(1.0 - length2(perp)))
    return perp + torch.stack(3 * (parallel_scale,), -1) * n
def reflectance(cosine, ref_index):
    """Schlick's approximation of Fresnel reflectance: r0 + (1 - r0) (1 - cos)^5."""
    ratio = (1 - ref_index) / (1 + ref_index)
    r0 = ratio * ratio
    return r0 + (1 - r0) * torch.pow(1 - cosine, 5)
class Dielectric(Material):
    """Clear refractive material; the index of refraction comes from a texture."""

    def __init__(self, ior: Texture):
        super().__init__()
        self.ior = ior

    def scatter(self, ray, rec, cc_index):
        """Reflect or refract, choosing stochastically by Schlick reflectance; attenuation is 1."""
        attenuation = color((1.0, 1.0, 1.0))[..., cc_index]
        ior = self.ior.value(rec.uv, rec.p)[..., cc_index]
        refraction_ratio = torch.where(rec.front_face, 1.0 / ior, ior)
        d = unit(ray.direction)
        one = torch.tensor([1.0], dtype=DTYPE, device=device)
        cos_theta = torch.minimum(dot(-d, rec.normal), one)
        sin_theta = torch.sqrt(1.0 - cos_theta * cos_theta)
        total_internal = refraction_ratio * sin_theta > 1.0
        must_reflect = total_internal | (reflectance(cos_theta, refraction_ratio) > uniform.sample(refraction_ratio.size()))
        direction = torch.where(must_reflect.unsqueeze(-1), reflect(d, rec.normal), refract(d, rec.normal, refraction_ratio))
        return True, Ray(rec.p, direction, ray.time), attenuation
class MetallicDielectric(Metal):
    """Dielectric whose reflected rays use Metal's fuzzy reflection and albedo."""

    def __init__(self, albedo: Texture, fuzz: Texture, ior: Texture):
        super().__init__(albedo, fuzz)
        self.ior = ior

    def scatter(self, ray, rec, cc_index):
        """Refract where possible; elsewhere use Metal.scatter's direction and attenuation."""
        clear = color((1.0, 1.0, 1.0))[..., cc_index]
        ior = self.ior.value(rec.uv, rec.p)[..., cc_index]
        ratio = torch.where(rec.front_face, 1.0 / ior, ior)
        d = unit(ray.direction)
        cos_theta = torch.minimum(dot(-d, rec.normal), torch.tensor([1.0], dtype=DTYPE, device=device))
        sin_theta = torch.sqrt(1.0 - cos_theta * cos_theta)
        reflect_mask = (ratio * sin_theta > 1.0) | (reflectance(cos_theta, ratio) > uniform.sample(ratio.size()))
        _, metal_ray, metal_att = super().scatter(ray, rec, cc_index)
        direction = torch.where(reflect_mask.unsqueeze(-1), metal_ray.direction, refract(d, rec.normal, ratio))
        attenuation = torch.where(reflect_mask, metal_att, clear)
        return True, Ray(rec.p, direction, ray.time), attenuation
class FuzzyDielectric(MetallicDielectric):
    """MetallicDielectric with a white (1, 1, 1) albedo: clear glass with fuzzy reflections."""

    def __init__(self, fuzz: Texture, ior: Texture):
        white = SolidColor(color([1.0, 1.0, 1.0]))
        super().__init__(white, fuzz, ior)

    def scatter(self, ray, rec, cc_index):
        """Same scattering as MetallicDielectric."""
        result = super().scatter(ray, rec, cc_index)
        return result
class DielectricMetallic(Metal):
    """Metal-first hybrid: uses Metal.scatter wherever its ray leaves the surface or refraction fails."""

    def __init__(self, albedo: Texture, fuzz: Texture, ior: Texture):
        super().__init__(albedo, fuzz)
        self.ior = ior

    def scatter(self, ray, rec, cc_index):
        """Refract only where Metal's scatter mask is False and refraction is chosen."""
        clear = color((1.0, 1.0, 1.0))[..., cc_index]
        ior = self.ior.value(rec.uv, rec.p)[..., cc_index]
        ratio = torch.where(rec.front_face, 1.0 / ior, ior)
        d = unit(ray.direction)
        cos_theta = torch.minimum(dot(-d, rec.normal), torch.tensor([1.0], dtype=DTYPE, device=device))
        sin_theta = torch.sqrt(1.0 - cos_theta * cos_theta)
        reflect_mask = (ratio * sin_theta > 1.0) | (reflectance(cos_theta, ratio) > uniform.sample(ratio.size()))
        metal_ok, metal_ray, metal_att = super().scatter(ray, rec, cc_index)
        use_metal = metal_ok | reflect_mask
        direction = torch.where(use_metal.unsqueeze(-1), metal_ray.direction, refract(d, rec.normal, ratio))
        attenuation = torch.where(use_metal, metal_att, clear)
        return True, Ray(rec.p, direction, ray.time), attenuation
class DiffuseLight(Material):
    """Emissive material. It does not scatter; it emits its texture's colour."""

    def __init__(self, emit) -> None:
        super().__init__()
        self.emit = emit  # texture of emitted radiance

    def scatter(self, ray, rec, cc_index):
        """Lights do not scatter: always returns None."""
        return None

    def emitted(self, ray, rec, cc_index):
        """Emitted radiance at the hit, for channel(s) cc_index (note the (ray, rec) signature)."""
        radiance = self.emit.value(rec.uv, rec.p)
        return radiance[..., cc_index]
class Isotropic(Material):
    """Phase function for participating media: scatters in a random direction."""

    def __init__(self, albedo: Texture) -> None:
        super().__init__()
        self.albedo = albedo

    def scatter(self, ray, rec, cc_index):
        """Return (True, ray from rec.p toward a random in-ball direction, albedo)."""
        bounce = random_in_unit_sphere(ray.direction.size(), dist=uniform)
        scattered = Ray(rec.p, bounce, ray.time)
        return True, scattered, self.albedo.value(rec.uv, rec.p)[..., cc_index]
class Camera:
    """Thin-lens camera that produces batched rays for a tile of (u, v, ...) coordinates.

    ``blur`` is the distribution handed to random_in_unit_disk for lens
    sampling. Its default is bound when the class is defined, when the
    module-level ``uniform`` is still None, so random_in_unit_disk then
    falls back to the current global ``uniform``.
    """

    def __init__(self, lookfrom, lookat, viewup, vfov, aspect, aperture, focus, blur=uniform, time0 = 0, time1 = 0):
        self.lookfrom = lookfrom
        self.lookat = lookat
        self.viewup = viewup
        self.vfov = vfov
        self.aspect = aspect
        self.aperture = aperture
        self.focus = focus
        self.blur = blur
        self.time0 = time0
        self.time1 = time1
        theta = vfov / 360.0 * 2.0 * math.pi
        h = math.tan(theta / 2.0)
        viewport_height = 2.0 * h
        viewport_width = aspect * viewport_height
        self.w = unit(lookfrom - lookat)
        self.u = unit(cross(viewup, self.w))
        self.v = cross(self.w, self.u)
        self.origin = lookfrom.to(device)
        self.horizontal = focus * viewport_width * self.u.to(device)
        self.vertical = focus * viewport_height * self.v.to(device)
        self.hv = self.horizontal + self.vertical
        self.lower_left = self.origin - self.horizontal / 2 - self.vertical / 2 - focus * self.w
        self.radius = aperture / 2
        self.cache = None  # per-tile broadcasts of the camera vectors, keyed by name

    def get_ray(self, uvw):
        """Return rays for the tile coordinates *uvw* (same shape as the cached buffers).

        The broadcast cache is rebuilt whenever uvw's shape, dtype or device
        changes. For example, smaller edge tiles would otherwise fail to
        broadcast against buffers sized for the first tile.
        """
        cached = None if self.cache is None else self.cache['origin']
        if (cached is None or cached.shape != uvw.shape
                or cached.dtype != uvw.dtype or cached.device != uvw.device):
            self.cache = {}
            for key in ('origin', 'lower_left', 'hv', 'u', 'v', 'w'):
                buf = torch.zeros_like(uvw)
                buf[...] = getattr(self, key)
                self.cache[key] = buf
        origin = self.cache['origin']
        rd = self.radius * random_in_unit_disk(origin.size(), self.blur)
        offset = self.cache['u'] * rd[..., 0:1] + self.cache['v'] * rd[..., 1:2] + self.cache['w'] * rd[..., 2:3]
        direction = self.cache['lower_left'] + uvw * self.cache['hv'] - origin
        times = uniform.sample(origin.size()[:2]) * (self.time1 - self.time0) + self.time0
        return Ray(origin + offset, direction - offset, times)
@torch.no_grad()
def main():
print('Hello!')
import random
import sys
import PIL
from PIL import Image,ImageTk
import tkinter
import time
global device, pi, inf, DTYPE
global uniform, normal, half_normal
# Image
width = 2048
height = 1024
# Render
tile_size = 512#128#1024#64
fn = sys.argv[1]
width = int(sys.argv[2])
height = int(sys.argv[3])
aspect = width / height
samples_per_pixel = int(sys.argv[4])
max_depth = int(sys.argv[5])
tile_size = int(sys.argv[6])
if len(sys.argv) > 7:
device = sys.argv[7]
if len(sys.argv) > 8:
dt = sys.argv[8]
DTYPE = {'16': torch.bfloat16 if device == 'cpu' else torch.float16, '32':torch.float32, '64':torch.float64}[dt]
autocast = False
if len(sys.argv) > 9:
autocast = bool(sys.argv[9])
uniform = distributions.uniform.Uniform(torch.tensor(0,dtype=DTYPE,device=device),torch.tensor(1,dtype=DTYPE,device=device))
normal = distributions.normal.Normal(torch.tensor(0,dtype=DTYPE,device=device),torch.tensor(0.5,dtype=DTYPE,device=device))
half_normal = distributions.half_normal.HalfNormal(torch.tensor(0.5,dtype=DTYPE,device=device))
pi = math.pi
inf = torch.tensor(math.inf, dtype=DTYPE, device=device)
# World
##material_ground = Lambertian(CheckerTexture(SolidColor(color([0.2,0.3,0.1])), SolidColor(color([0.9,0.9,0.9]))))#Metal(color([0.8,0.8,0.0]), 3.0)#Lambertian(color([0.8,0.8,0.0]))#Metal(color([0.8,0.8,0.0]), 3.0)#Lambertian(color([0.8,0.8,0.0]))
##pertext = NoiseTexture(SolidColor(color([4,4,4])))
##material_center = Lambertian(pertext)#Lambertian(SolidColor(color([0.8,0.3,0.1])))#Dielectric(1.5)
earth_image = Image.open('earthmap.jpg')
earth_image.load()
earth_image = torch.tensor(numpy.asarray(earth_image),dtype=DTYPE,device=device)
#print(earth_image.size())
earth_image.transpose_(1,0)
#print(earth_image.size())
##earth_texture = ImageTexture(earth_image)
##material_left = Lambertian(earth_texture)
#material_left = FuzzyDielectric(SolidColor(color(3*(1.5,))), SolidColor(color(3*(1.5,))))#DielectricMetallic(color([0.9,0.4,0.4]), 2.5, 2.5)#FuzzyDielectric(1.0, 1.5)#Metal(color([0.8,0.8,0.8]), 0.3)
##material_right = Metal(SolidColor(color([0.8,0.6,0.2])), SolidColor(color([2.8,1.6,1.2])))
#print(material_ground, material_center)
"""
world = Object3s()
# Spheres
world.add(Sphere(0, 1, point([ 0.0, -100.5, -1.0]), point([ 0.0, -100.5, -1.0]), 100.0, material_ground))
world.add(Sphere(0, 1, point([ 0.0, 0.2, -1.25]), point([ 0.0, 0.5 * torch.rand((1,)), -1.25]), 0.5, material_left))
#world.add(Sphere(0, 1, point([-1.25, 0.0, -1.0]), point([-1.25, 0.5 * torch.rand((1,)), -1.0]), 0.5, material_center))
world.add(Sphere(0, 1, point([-1.25, 0.0, -1.0]), point([-1.25, 0.0, -1.0]), 0.5, material_center))
world.add(Sphere(0, 1, point([ 1.25, 0.0, -1.0]), point([ 1.25, 0.0, -1.0]), 0.5, material_right))
# Lights
#world.add(Sphere(0, 1, point([ 0.0, 3.0, 0.0]), point([ 0.0, 3.0, 0.0]), 0.5, DiffuseLight(SolidColor(color([7.2,7.2,5.8])))))
#world.add(Sphere(0, 1, point([ -20.0, 1.0, 5.0]), point([ -20.0, 1.0, 5.0]), 0.25, DiffuseLight(SolidColor(color([0.1,0.3,0.1])))))
difflight = DiffuseLight(SolidColor(color((4,4,4))))
world.add(XY_Rect(3,5,1,3,-2,difflight))
# Bounding Volume Hiearchy
world_bvh = bvh_node(world, 0, 1)
# Camera
lookfrom = point((0,2,4))
lookat = point((0,0,-1))
viewup = vector((0,1,0))
aperture = 0.3
camera = Camera(lookfrom, lookat, viewup, 22.5, aspect, aperture, length(lookfrom-lookat), blur=half_normal, time0=0, time1=1)
background = color((0.0,0.0,0.0))#color((0.8,0.2,0.8))
"""
"""
# Cornell Box
cornell_box = Object3s()
red = Lambertian(SolidColor(color((.65,.05,.05))))
white = Lambertian(SolidColor(color((.73,.73,.73))))
green = Lambertian(SolidColor(color((.12,.45,.15))))
#light = DiffuseLight(SolidColor(color((15,15,15))))
light = DiffuseLight(SolidColor(color((7,7,7))))
cornell_box.add(YZ_Rect(0,555,0,555,555,green))
cornell_box.add(YZ_Rect(0,555,0,555,0,red))
#cornell_box.add(XZ_Rect(213,343,227,332,554,light))
cornell_box.add(XZ_Rect(113,443,127,432,554,light))
cornell_box.add(XZ_Rect(0,555,0,555,0,white))
cornell_box.add(XZ_Rect(0,555,0,555,555,white))
cornell_box.add(XY_Rect(0,555,0,555,555,white))
#cornell_box.add(Box(point((130,0,65)),point((295,165,230)),white))
#cornell_box.add(Box(point((265,0,295)),point((430,330,460)),white))
box1 = Box(point((0,0,0)), point((165,330,165)), white)
box1 = Rotate_Y(box1, 15)
box1 = Translate(box1, vector((265,0,295)))
box2 = Box(point((0,0,0)), point((165,165,165)), white)
box2 = Rotate_Y(box2, -18)
box2 = Translate(box2, vector((130,0,65)))
#cornell_box.add(box1)
#cornell_box.add(box2)
cornell_box.add(ConstantMedium(box1, 0.01, SolidColor(color((0,0,0)))))
cornell_box.add(ConstantMedium(box2, 0.01, SolidColor(color((1,1,1)))))
world_bvh = bvh_node(cornell_box)
background = color((0,0,0))
lookfrom = point((278,278,-800))
lookat = point((278,278,0))
viewup = vector((0,1,0))
vfov= 40.0
aperture = 0.0
camera = Camera(lookfrom, lookat, viewup, vfov, aspect, aperture, length(lookfrom-lookat), blur=half_normal, time0=0, time1=0)
"""
boxes1 = Object3s()
ground = Lambertian(SolidColor(color([0.48,0.83,0.53])))
boxes_per_side = 20
for i in range(boxes_per_side):
for j in range(boxes_per_side):
w = 100.0
x0 = -1000.0 + i*w
z0 = -1000.0 + j*w
y0 = 0.0
x1 = x0 + w
y1 = random.random()*100 + 1.0
z1 = z0 + w
boxes1.add(Box(point([x0,y0,z0]),point([x1,y1,z1]),ground))
objects = Object3s()
objects.add(bvh_node(boxes1, 0, 1))
light = DiffuseLight(SolidColor(color([7,7,7])))
objects.add(XZ_Rect(123,423,147,412,554,light))
center1 = point([400, 400, 200])
center2 = center1 + vector([30,0,0])
moving_sphere_material = Lambertian(SolidColor(color([0.7,0.3,0.1])))
objects.add(Sphere(0,1,center1,center2,50,moving_sphere_material))
center3 = point([260,150,45])
objects.add(Sphere(0,1,center3,center3,50,Dielectric(SolidColor(color([1.5,1.5,1.5])))))
center4 = point([0,150,145])
objects.add(Sphere(0,1,center4,center4,50,Metal(SolidColor(color([0.8,0.8,0.9])), SolidColor(color((1.0,1.0,1.0))))))
center5 = point([360,150,145])
boundary = Sphere(0,1,center5,center5, 70, Dielectric(SolidColor(color([1.5,1.5,1.5]))))
objects.add(boundary)
objects.add(ConstantMedium(boundary, 0.2, SolidColor(color([0.2,0.4,0.9]))))
boundary = Sphere(0,1,point([0,0,0]),point([0,0,0]), 5000, Dielectric(SolidColor(color([1.5,1.5,1.5]))))
objects.add(ConstantMedium(boundary, .0001, SolidColor(color([1,1,1]))))
emat = Lambertian(ImageTexture(earth_image))
objects.add(Sphere(0,1,point([400,200,400]),point([400,200,400]),100,emat))
pertext = NoiseTexture(SolidColor(color([0.1,0.1,0.1])))
objects.add(Sphere(0,1,point([220,280,300]),point([220,280,300]), 80, Lambertian(pertext)))
boxes2 = Object3s()
white = Lambertian(SolidColor(color([.73,.73,.73])))
ns = 1000
for j in range(ns):
center = random_in_unit_sphere((),dist=uniform) * 165
boxes2.add(Sphere(0,1,center,center,10,white))
objects.add(Translate(Rotate_Y(bvh_node(boxes2, 0.0, 1.0), 15), vector([-100,270,395])))
world = objects
background = color((0,0,0))
lookfrom = point((478,278,-600))
lookat = point((278,278,0))
viewup = vector((0,1,0))
vfov = 40.0
aperture = 0.0
camera = Camera(lookfrom, lookat, viewup, vfov, aspect, aperture, length(lookfrom-lookat), blur=half_normal, time0=0, time1=1)
def scatter(ray, rec, background, world, depth, cc_index):
    """Shade the hit points in ``rec`` for one colour channel and recurse once.

    Every registered material contributes its emission, masked to the pixels
    whose ``rec.target`` equals that material's index.  Materials whose
    ``scatter`` returns a result also contribute an attenuation factor, and
    their scattered rays are merged into one batched ``Ray`` (the first
    material that scatters a pixel supplies that pixel's ray), so only a
    single recursive ``ray_color`` call is made for this bounce.

    Args:
        ray: incoming batch of rays.
        rec: hit record filled in by ``world.hit`` for ``ray``.
        background: background colour, forwarded to ``ray_color``.
        world: scene object, forwarded to ``ray_color``.
        depth: remaining bounce budget; the recursion receives ``depth - 1``.
        cc_index: index of the colour channel being rendered.

    Returns:
        The sum over materials of ``emission + attenuation * incoming``, where
        ``incoming`` is the colour traced along the merged scattered rays;
        the float ``0.0`` if no materials are registered.
    """
    origins = torch.zeros_like(ray.origin)
    directions = torch.zeros_like(ray.origin)
    times = torch.zeros_like(ray.time)
    # Pixels already claimed by an earlier scattering material.
    claimed = torch.zeros_like(ray.origin[:,:], dtype=torch.bool)
    contributions = []
    for mat in Material.instances.values():
        glow = mat.emitted(ray, rec, cc_index)
        glow = torch.where(rec.target == mat.index, glow, torch.zeros_like(glow))
        gain = 0.0
        outcome = mat.scatter(ray, rec, cc_index)
        if outcome is not None:
            mask, bounced, gain = outcome
            # In-place AND: restricts the material's own mask tensor to the
            # pixels this material was actually hit on.
            mask &= (rec.target == mat.index)
            mask = mask.unsqueeze(-1)
            fresh = mask & ~claimed
            origins = torch.where(fresh, bounced.origin, origins)
            directions = torch.where(fresh, bounced.direction, directions)
            times = torch.where(fresh[..., 0], bounced.time, times)
            claimed = claimed | mask
            gain = torch.where(mask.squeeze(-1), gain, torch.zeros_like(gain))
        contributions.append((gain, glow))
    incoming = ray_color(Ray(origins, directions, times), background, world, depth - 1, cc_index)
    # Accumulate left to right starting from 0.0, exactly like a running sum.
    return sum((glow + gain * incoming for gain, glow in contributions), 0.0)
def ray_color(ray, background: color, world, depth, cc_index):
    """Trace ``ray`` for colour channel ``cc_index``.

    Returns the float ``0.0`` once the bounce budget ``depth`` is exhausted.
    Otherwise the rays are intersected with ``world`` for ray parameters in
    ``[0.001, 500.0]``.  Pixels that hit something are shaded by ``scatter``,
    which recurses with ``depth - 1``; all other pixels take
    ``background[cc_index]``.
    """
    if depth > 0:
        record = Record(ray)
        # The 0.001 lower bound avoids re-hitting the surface a ray starts on.
        t_min = torch.tensor((0.001,), dtype=DTYPE, device=device)
        t_max = torch.tensor((500.0,), dtype=DTYPE, device=device)
        hit = world.hit(ray, t_min, t_max, record)
        shaded = scatter(ray, record, background, world, depth, cc_index)
        return torch.where(hit, shaded, background[cc_index])
    return 0.0
display_window_width = 1024
display_window_height = int(display_window_width / aspect)
root = tkinter.Tk()
root.title(f'Rendering {fn}')
root.config(bg=None)
root.minsize(width=display_window_width, height=display_window_height)
root.geometry(f'{display_window_width}x{display_window_height}')
canvas = tkinter.Canvas(root,bg=None)
canvas.pack(expand=1,anchor=tkinter.CENTER, fill='both')
canvas_images = {}
canvas_rectangles = {}
resizing = False
def redraw(w,h):
nonlocal resizing
nonlocal display_window_width, display_window_height
display_window_width, display_window_height = w, h
for (i,j) in canvas_images:
item = canvas_images[i,j]
image = ImageTk.getimage(item[1])
image = image.resize((max(1,int(tile_size*w/width)),max(1,int(tile_size*h/height))),resample=PIL.Image.Resampling.BILINEAR)
image = ImageTk.PhotoImage(image=image)
item[1] = image
canvas.coords(item[0],i*w/width,max(0,(height-tile_size-j))*h/height)
x,y = canvas.coords(item[0])
canvas.itemconfig(item[0], {'image': image})
canvas.coords(canvas_rectangles[i,j],x,y,x+image.width(),y+image.height())
canvas.update_idletasks()
resizing = False
def resize(event):
nonlocal resizing
if event.widget == root:
resizing = True
w,h = event.width,event.height
h = min(w / aspect, event.height)
w = aspect * min(h, event.height)
canvas.configure(width=w,height=h)
redraw(w,h)
root.unbind_all('<Configure>')
root.bind('<Configure>', resize)
dt = torch.float16
if device == 'cpu':
dt = torch.bfloat16
with torch.autocast(device_type=device, dtype=dt, enabled=autocast):
beginning = time.time()
print(f'Rendering: {width} × {height} (× {samples_per_pixel} samples) ')
v = torch.arange(height, dtype=DTYPE, device='cpu')/(height-1)
u = torch.arange(width, dtype=DTYPE, device='cpu')/(width-1)
w = torch.zeros((height,width), dtype=DTYPE, device='cpu')
v = v.repeat((width,1)).transpose(1,0)
u = u.repeat((height,1))
uvw_full = torch.stack((u,v,w), dim=-1)
image = torch.zeros((height,width,3), dtype=DTYPE, device='cpu')
im = Image.new('RGB', (width, height))
k,j,i,cc_index = samples_per_pixel,0,0,0
def process():
nonlocal image, im
nonlocal resizing
nonlocal k,j,i,cc_index
now = time.time() - beginning
print(f'\rRemaining:\t{k:>4}\tTime:\t{now:>.2f}\t(seconds)', end='')
if k > 0:
rand_full = random_in_unit_disk(uvw_full.size(), uniform)#normal.sample(uvw_full.size())
uvw_full_ = uvw_full.to(device)
if cc_index < 3:
if j < height:
if i < width:
uvw = uvw_full_[j:j+tile_size,i:i+tile_size,:]
rand = rand_full[j:j+tile_size,i:i+tile_size,:]
rand[...,0] /= width-1
rand[...,1] /= height-1
rand[...,2] = 0
ray = camera.get_ray(uvw + rand)
p = ray_color(ray, background, world, max_depth, cc_index)
assert(not torch.any(p.isnan()))
image[j:j+tile_size,i:i+tile_size,cc_index] += p.to('cpu') / samples_per_pixel
image_ = image[j:j+tile_size,i:i+tile_size,:]
image_size = image_.size()
image_ = torch.minimum(image_, torch.ones_like(image_))
image_ = torch.sqrt(image_)
image_ = torch.floor(255.999*image_)
image_ = image_.to('cpu').numpy().astype(numpy.uint8)[::-1,:]
im = Image.fromarray(image_, 'RGB')
im_ = ImageTk.PhotoImage(image=im.resize((int(image_size[1]*display_window_width/width),int(image_size[0]*display_window_height/height)),resample=PIL.Image.Resampling.BILINEAR))
if (i,j) not in canvas_images:
img = canvas.create_image(i*display_window_width/width,max(0,(height-tile_size-j)*display_window_height/height),anchor=tkinter.NW,image=im_)
canvas_images[i,j] = [img,im_]
rect = canvas.create_rectangle(i*display_window_width/width,max(0,(height-tile_size-j)*display_window_height/height),i*display_window_width/width+image_size[1]*display_window_width/width+1,max(0,(height-tile_size-j)*display_window_height/height)+image_size[0]*display_window_height/height+1,
outline = "red", width = 3)
canvas_rectangles[i,j] = rect
else:
canvas.itemconfigure(canvas_images[i,j][0], {'image': im_})
canvas_images[i,j][1] = im_
canvas.update_idletasks()
i += tile_size
else:
j += tile_size
i = 0
else:
i = 0
j = 0
cc_index += 1
else:
k -= 1
j = 0
i = 0
cc_index = 0
root.after(50 if resizing else 0,process)
else:
pass
process()
root.mainloop()
image = torch.minimum(image, torch.ones_like(image))
image = torch.sqrt(image)
image = torch.floor(255.999*image)
im = Image.fromarray(image.to('cpu').numpy().astype(numpy.uint8)[::-1,:], 'RGB')
im.save(fn)
print('\nDone.')
# Run the renderer only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
@omnp
Copy link
Author

omnp commented Nov 23, 2024

@omnp
Copy link
Author

omnp commented Nov 23, 2024

I forgot exactly which parameters were given on the command line, but here is the image.
final_scene-cuda-correct-aspect

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment