Skip to content

Instantly share code, notes, and snippets.

@uvolchyk
Created September 29, 2025 19:09
Show Gist options
  • Select an option

  • Save uvolchyk/0a0350d039e6fab89334fe465b7c0563 to your computer and use it in GitHub Desktop.

Select an option

Save uvolchyk/0a0350d039e6fab89334fe465b7c0563 to your computer and use it in GitHub Desktop.
import ModelIO
import MetalKit
/// Loads the first `MDLMesh` of a Model I/O asset (e.g. an OBJ file) into Metal
/// buffers, together with one base-color texture per submesh.
public struct ObjectParser {
    /// GPU mesh holding the vertex data shared by every submesh (unordered vertices).
    public let mesh: MTKMesh
    /// Submeshes of `mesh`; each one's index buffer orders the shared vertices into primitives.
    public let submeshes: [MTKSubmesh]
    /// Base-color textures, index-aligned with `submeshes`: `textures[i]` belongs to
    /// `submeshes[i]`. A submesh whose material has no loadable base-color texture
    /// gets a 1x1 opaque-white placeholder so the two arrays never drift apart.
    public var textures: [MTLTexture]
    /// Interleaved layout requested from Model I/O, all in buffer 0 with a 32-byte stride:
    /// attribute 0 = position (float3, offset 0), attribute 1 = normal (float3, offset 12),
    /// attribute 2 = texture coordinate (float2, offset 24).
    public let mdlVertexDescriptor: MDLVertexDescriptor = {
        let descriptor = MDLVertexDescriptor()
        descriptor.attributes[0] = MDLVertexAttribute(
            name: MDLVertexAttributePosition,
            format: .float3,
            offset: 0,
            bufferIndex: 0
        )
        descriptor.attributes[1] = MDLVertexAttribute(
            name: MDLVertexAttributeNormal,
            format: .float3,
            offset: 12,
            bufferIndex: 0
        )
        descriptor.attributes[2] = MDLVertexAttribute(
            name: MDLVertexAttributeTextureCoordinate,
            format: .float2,
            offset: 24,
            bufferIndex: 0
        )
        descriptor.layouts[0] = MDLVertexBufferLayout(stride: 32)
        return descriptor
    }()

    /// Loads `modelURL` and uploads its first mesh and base-color textures to `device`.
    ///
    /// Normals are generated when the mesh reports none. Texture file names stored
    /// as strings in the material are resolved relative to the model's directory.
    ///
    /// - Parameters:
    ///   - modelURL: Location of the model file.
    ///   - device: Device that owns the created buffers and textures.
    /// - Note: Traps with a descriptive message if the asset contains no mesh or the
    ///   mesh cannot be converted to an `MTKMesh`. The non-throwing signature is kept
    ///   for source compatibility with existing callers.
    public init(
        modelURL: URL,
        device: MTLDevice
    ) {
        let allocator = MTKMeshBufferAllocator(device: device)
        let asset = MDLAsset(
            url: modelURL,
            vertexDescriptor: mdlVertexDescriptor,
            bufferAllocator: allocator
        )

        // Only the first mesh in the asset is used.
        guard let mdlMesh = asset.childObjects(of: MDLMesh.self).first as? MDLMesh else {
            fatalError("ObjectParser: no MDLMesh found in \(modelURL.lastPathComponent)")
        }

        if mdlMesh.vertexAttributeData(forAttributeNamed: MDLVertexAttributeNormal, as: .float3) == nil {
            mdlMesh.addNormals(withAttributeNamed: MDLVertexAttributeNormal, creaseThreshold: 0.0)
        }

        let mesh: MTKMesh
        do {
            mesh = try MTKMesh(mesh: mdlMesh, device: device)
        } catch {
            fatalError("ObjectParser: failed to create MTKMesh from \(modelURL.lastPathComponent): \(error)")
        }
        self.mesh = mesh
        self.submeshes = mesh.submeshes

        let loader = MTKTextureLoader(device: device)
        let modelDirectory = modelURL.deletingLastPathComponent()
        // Per-element cast so one unexpected entry does not discard every material.
        let mdlSubmeshes: [MDLSubmesh?] = mdlMesh.submeshes?.map { $0 as? MDLSubmesh } ?? []

        var placeholder: MTLTexture?
        var loaded: [MTLTexture] = []
        loaded.reserveCapacity(mesh.submeshes.count)
        for index in mesh.submeshes.indices {
            let material = index < mdlSubmeshes.count ? mdlSubmeshes[index]?.material : nil
            if let texture = Self.loadBaseColorTexture(of: material, relativeTo: modelDirectory, using: loader) {
                loaded.append(texture)
                continue
            }
            // Keep `textures[i]` aligned with `submeshes[i]` even when a material has no texture.
            if placeholder == nil {
                placeholder = Self.makePlaceholderTexture(device: device)
            }
            if let placeholder {
                loaded.append(placeholder)
            }
        }
        self.textures = loaded
    }

    /// Loader options shared by every base-color load: no sRGB conversion, bottom-left origin.
    private static var textureOptions: [MTKTextureLoader.Option: Any] {
        [
            .SRGB: false as NSNumber,
            .origin: MTKTextureLoader.Origin.bottomLeft
        ]
    }

    /// Returns the material's base-color texture, or `nil` when the material has no
    /// base-color property or the texture fails to load. Handles three property
    /// types: a file name (resolved against `directory`), a URL, or an embedded `MDLTexture`.
    private static func loadBaseColorTexture(
        of material: MDLMaterial?,
        relativeTo directory: URL,
        using loader: MTKTextureLoader
    ) -> MTLTexture? {
        guard let property = material?.property(with: .baseColor) else {
            return nil
        }
        switch property.type {
        case .string:
            guard let name = property.stringValue else { return nil }
            return try? loader.newTexture(URL: directory.appendingPathComponent(name), options: textureOptions)
        case .URL:
            guard let url = property.urlValue else { return nil }
            return try? loader.newTexture(URL: url, options: textureOptions)
        case .texture:
            guard let embedded = property.textureSamplerValue?.texture else { return nil }
            return try? loader.newTexture(texture: embedded, options: textureOptions)
        default:
            return nil
        }
    }

    /// Creates a 1x1 opaque-white RGBA8 texture used for submeshes without a base-color map.
    private static func makePlaceholderTexture(device: MTLDevice) -> MTLTexture? {
        let descriptor = MTLTextureDescriptor.texture2DDescriptor(
            pixelFormat: .rgba8Unorm,
            width: 1,
            height: 1,
            mipmapped: false
        )
        descriptor.usage = .shaderRead
        guard let texture = device.makeTexture(descriptor: descriptor) else {
            return nil
        }
        let whitePixel: [UInt8] = [255, 255, 255, 255]
        texture.replace(
            region: MTLRegionMake2D(0, 0, 1, 1),
            mipmapLevel: 0,
            withBytes: whitePixel,
            bytesPerRow: 4
        )
        return texture
    }
}
#include <metal_stdlib>
using namespace metal;
// Per-draw scene constants read by modelVertex at vertex buffer 1.
// `projection` is applied directly to world-space positions, so it must already
// include the camera transform (the Swift renderer fills it with perspective * view).
struct SceneUniforms {
float4x4 projection;
};
// Per-vertex input fetched through [[stage_in]]. Attribute indices must match the
// host vertex descriptor: 0 = position, 1 = normal, 2 = texture coordinate.
struct VertexIn {
float3 position [[attribute(0)]];
float3 normal [[attribute(1)]];
float2 uv [[attribute(2)]];
};
// Rasterizer input for modelFragment.
struct VertexOut {
// Clip-space position.
float4 position [[position]];
// World-space position, used to build the light vector per fragment.
float3 worldPosition;
// World-space normal; not unit length, modelFragment normalizes it.
float3 normal;
float2 uv;
};
// Instanced model vertex shader.
// Picks this instance's model matrix, outputs the clip-space position, and
// passes world-space position, normal, and UV on to the fragment stage.
//   buffer(1): SceneUniforms; `projection` already contains the view transform.
//   buffer(2): one float4x4 model matrix per instance, indexed by [[instance_id]].
vertex VertexOut modelVertex(
    VertexIn in [[stage_in]],
    constant SceneUniforms &u [[buffer(1)]],
    constant float4x4 *instanceModels [[buffer(2)]],
    uint instanceID [[instance_id]]
) {
    float4x4 model = instanceModels[instanceID];
    float4 worldPos = model * float4(in.position, 1.0);

    // Normals must use the inverse-transpose of the upper-left 3x3. Otherwise
    // they tilt under non-uniform scale. MSL has no matrix inverse, so use the
    // cofactor matrix, which equals det * inverse-transpose. The scale factor
    // disappears when the fragment stage normalizes. Multiply by the sign of det
    // so mirrored transforms do not flip normals inward.
    float3 c0 = model[0].xyz;
    float3 c1 = model[1].xyz;
    float3 c2 = model[2].xyz;
    float3x3 cofactor = float3x3(cross(c1, c2), cross(c2, c0), cross(c0, c1));
    float det = dot(c0, cofactor[0]);
    float handedness = det < 0.0 ? -1.0 : 1.0;

    VertexOut out;
    out.worldPosition = worldPos.xyz;
    // World -> clip via the combined view-projection matrix.
    out.position = u.projection * worldPos;
    out.normal = (cofactor * in.normal) * handedness;
    // Pass UV coordinates through unchanged.
    out.uv = in.uv;
    return out;
}
// Single point light, bound at fragment buffer 1 of modelFragment.
// Field order and types must stay in sync with the host-side PointLight struct.
struct PointLight {
// World-space light position.
float3 position;
float3 color;
// Scalar multiplier applied to `color`.
float intensity;
// Quadratic falloff coefficient k in 1 / (1 + k * d^2).
float attenuation;
};
// Lambertian diffuse shading from one point light, tinted by the base-color
// texture. There is no ambient or specular term, and alpha is always 1.
fragment float4 modelFragment(
    VertexOut in [[stage_in]],
    constant PointLight &light [[buffer(1)]],
    texture2d<float> tex [[texture(0)]],
    sampler samp [[sampler(0)]]
) {
    float3 albedo = tex.sample(samp, in.uv).rgb;
    float3 N = normalize(in.normal);

    float3 toLight = light.position - in.worldPosition;
    float lightDistance = length(toLight);
    float3 L = normalize(toLight);

    // Quadratic falloff 1 / (1 + k * d^2). It equals 1 at the light and never
    // divides by zero, unlike a pure inverse-square law.
    float falloff = 1.0 / (1.0 + light.attenuation * lightDistance * lightDistance);

    float lambert = max(dot(N, L), 0.0);
    float3 irradiance = light.color * light.intensity * lambert * falloff;

    return float4(albedo * irradiance, 1.0);
}
// Output of blitVertex: clip-space position plus the UV used to sample the source.
struct BlitVertexOut {
float4 position [[position]];
float2 uv;
};
// Full-screen triangle generated from vertex_id alone, with no vertex buffer.
// Vertices are (-1,-1), (3,-1), (-1,3) in clip space with UVs (0,0), (2,0), (0,2),
// so the visible [-1,1] square receives UVs in [0,1].
vertex BlitVertexOut blitVertex(
    uint vid [[vertex_id]]
) {
    float2 corner = float2(vid == 1u ? 2.0 : 0.0,
                           vid == 2u ? 2.0 : 0.0);

    BlitVertexOut result;
    result.uv = corner;
    result.position = float4(corner * 2.0 - 1.0, 0.0, 1.0);
    return result;
}
// Copies the source texture to the target with V flipped (v -> 1 - v).
// The result is forced opaque.
fragment float4 blitFragment(
    BlitVertexOut in [[stage_in]],
    texture2d<float> src [[texture(0)]],
    sampler samp [[sampler(0)]]
) {
    float2 flippedUV = float2(in.uv.x, 1.0 - in.uv.y);
    float3 rgb = src.sample(samp, flippedUV).rgb;
    return float4(rgb, 1.0);
}
// Parameters for quantFragment, bound at fragment buffer 0.
// The layout must match the host-side mirror struct: float3 is 16-byte aligned,
// so gridColor starts at offset 16.
struct PostProcessUniforms {
float2 viewportSize; // in pixels
float pixelSize; // pixelation block size in pixels
float lineThickness; // grid line thickness in pixels
float3 gridColor; // RGB for grid
float gridAlpha; // alpha for grid overlay
};
// Output of quantVertex: clip-space position plus full-screen UV.
struct PostProcessVertexOut {
float4 position [[position]];
float2 uv;
};
// Fullscreen triangle vertex shader for the post-process pass.
// Derives each corner from vertex_id by bit tricks instead of lookup tables:
//   vid 0 -> uv (0,0), vid 1 -> uv (2,0), vid 2 -> uv (0,2)
// Clip-space position is uv * 2 - 1, i.e. (-1,-1), (3,-1), (-1,3).
vertex PostProcessVertexOut quantVertex(
    uint vid [[vertex_id]]
) {
    float2 corner = float2(float((vid << 1) & 2u), float(vid & 2u));

    PostProcessVertexOut result;
    result.position = float4(corner * 2.0 - 1.0, 0.0, 1.0);
    result.uv = corner;
    return result;
}
// Pixelation post-process with a grid overlay.
// Each screen pixel takes the color sampled at the center of its
// pixelSize x pixelSize block. A grid line of lineThickness pixels is blended
// over block edges using gridColor and gridAlpha.
fragment float4 quantFragment(
    PostProcessVertexOut in [[stage_in]],
    constant PostProcessUniforms &u [[buffer(0)]],
    texture2d<float> colorTex [[texture(0)]],
    sampler samp [[sampler(0)]]
) {
    // Width in pixels of the anti-aliased fade at the outside of each grid line.
    const float edgeSoftness = 0.6;

    float2 texSize = u.viewportSize;
    // A zero or negative block size would divide by zero and produce NaNs below.
    float cell = max(u.pixelSize, 1.0);

    // V is flipped to match the orientation of the offscreen scene texture.
    float2 uv = float2(in.uv.x, 1.0 - in.uv.y);
    float2 px = uv * texSize;

    // Snap to the center of the enclosing block and sample the scene there.
    float2 block = floor(px / cell) * cell + 0.5 * cell;
    float2 qUV = block / texSize;
    float3 base = colorTex.sample(samp, qUV).rgb;

    // Distance in pixels to the nearest block edge on either axis.
    float2 modv = fmod(px, cell);
    float2 distToEdge = min(modv, cell - modv);
    float edgeDist = min(distToEdge.x, distToEdge.y);

    // smoothstep is undefined when edge0 >= edge1, so evaluate the ascending ramp
    // and invert it. This gives the same curve as the reversed-edge call:
    // 1 on the line, fading to 0 over edgeSoftness pixels.
    float lineMask = 1.0 - smoothstep(u.lineThickness, u.lineThickness + edgeSoftness, edgeDist);

    float3 colorWithGrid = mix(base, u.gridColor, lineMask * u.gridAlpha);
    return float4(colorWithGrid, 1.0);
}
import SwiftUI
import MetalKit
/// CPU-side mirror of the shader's `PointLight`, uploaded with `setFragmentBytes`.
/// Field order and types must match the Metal struct exactly. `SIMD3<Float>` is
/// 16-byte aligned like Metal's `float3`.
struct PointLight {
// World-space light position.
var position: SIMD3<Float>
var color: SIMD3<Float>
// Scalar multiplier applied to `color` in the shader.
var intensity: Float
// Quadratic falloff coefficient k; the shader uses 1 / (1 + k * d^2).
var attenuation: Float
}
/// SwiftUI bridge that hosts a continuously redrawing `MTKView` driven by a `Renderer`.
struct MetalViewRepresentable: UIViewRepresentable {
    /// Keeps the renderer alive for the lifetime of the representable and holds a
    /// non-owning reference to the hosted view.
    final class Coordinator: NSObject {
        let renderer = Renderer()
        weak var mtkView: MTKView?
    }

    func makeCoordinator() -> Coordinator {
        return Coordinator()
    }

    /// Builds the `MTKView` on the renderer's device and makes the renderer its delegate.
    func makeUIView(context: Context) -> MTKView {
        let coordinator = context.coordinator
        let renderer = coordinator.renderer
        let metalView = MTKView(frame: .zero, device: renderer.device)

        // Attachment formats. The renderer's offscreen and pipeline setup reads these back.
        metalView.clearColor = MTLClearColorMake(0, 0, 0, 1)
        metalView.colorPixelFormat = .bgra8Unorm
        metalView.depthStencilPixelFormat = .depth32Float

        // Timer-driven redraws at up to 60 fps rather than on-demand display.
        metalView.preferredFramesPerSecond = 60
        metalView.isPaused = false
        metalView.enableSetNeedsDisplay = false

        // The drawable is only rendered into (final blit), never sampled.
        metalView.framebufferOnly = true

        metalView.delegate = renderer
        coordinator.mtkView = metalView
        return metalView
    }

    /// No SwiftUI state flows into the view, so there is nothing to update.
    func updateUIView(
        _ uiView: MTKView,
        context: Context
    ) {}
}
/// Owns the long-lived Metal objects for the scene: device, command queue,
/// shader library, loaded model, and the per-instance matrix buffer.
/// Frame encoding lives in extensions.
final class Renderer: NSObject {
    let library: ShaderLibrary
    let device: any MTLDevice = MTLCreateSystemDefaultDevice()!
    let commandQueue: any MTLCommandQueue
    let mdlObject: ObjectParser

    // Animation clock (seconds) and spin rate (radians per second) about +Y.
    private var animationTime: Float = 0.0
    private var rotationSpeed: Float = 1.0

    /// Placement of each drawn instance: a single instance translated by
    /// (0, -10, 0) with a uniform scale of 0.6.
    private var instanceTransforms: [AffineTransform] = [
        AffineTransform(
            translation: SIMD3<Float>(0.0, -10.0, 0.0),
            scale: SIMD3<Float>(repeating: 0.6)
        )
    ]

    /// Buffer sized at init for one `float4x4` per entry in `instanceTransforms`.
    private let instanceBuffer: MTLBuffer

    /// - Parameter modelURL: OBJ file to load. Defaults to the bundled anemone
    ///   model and traps if that resource is missing from the main bundle.
    init(
        modelURL: URL = Bundle.main.url(
            forResource: "12973_anemone_flower_v1_l2",
            withExtension: "obj"
        )!
    ) {
        let compiledShaders = try! device.makeDefaultLibrary(bundle: .main)
        library = ShaderLibrary(library: compiledShaders)

        commandQueue = device.makeCommandQueue()!

        mdlObject = ObjectParser(modelURL: modelURL, device: device)

        let instanceByteCount = MemoryLayout<float4x4>.stride * instanceTransforms.count
        instanceBuffer = device.makeBuffer(length: instanceByteCount, options: [])!

        super.init()
    }
}
extension Renderer: MTKViewDelegate {
    /// Renders one frame as three passes on a single command buffer:
    /// 1. the model into an offscreen color + depth target,
    /// 2. the pixelation/grid post-process into a second offscreen target,
    /// 3. a full-screen blit of that result into the view's drawable.
    ///
    /// Advances `animationTime` by a fixed 0.016 per call. Encoding errors from the
    /// draw helpers are fatal.
    func draw(in view: MTKView) {
        // Reuse the queue created once in `init`; creating a command queue every
        // frame is expensive and unnecessary.
        guard
            let drawable = view.currentDrawable,
            let commandBuffer = commandQueue.makeCommandBuffer()
        else {
            return
        }
        animationTime += 0.016

        let width = max(1, Int(view.drawableSize.width))
        let height = max(1, Int(view.drawableSize.height))

        // NOTE(perf): these targets are reallocated every frame. They only depend on
        // the drawable size and could be cached and rebuilt in
        // `mtkView(_:drawableSizeWillChange:)`.
        guard
            let modelTexture = makeRenderTarget(
                pixelFormat: view.colorPixelFormat,
                width: width,
                height: height,
                usage: [.renderTarget, .shaderRead]
            ),
            let modelDepthTexture = makeRenderTarget(
                pixelFormat: .depth32Float,
                width: width,
                height: height,
                usage: [.renderTarget]
            ),
            let postProcessTexture = makeRenderTarget(
                pixelFormat: view.colorPixelFormat,
                width: width,
                height: height,
                usage: [.renderTarget, .shaderRead]
            )
        else {
            // Without the intermediate targets no pass can produce a valid image.
            return
        }

        // ---- MODEL: scene -> modelTexture (depth-tested, depth discarded) ----
        let offscreenPassDescriptor = MTLRenderPassDescriptor()
        offscreenPassDescriptor.colorAttachments[0].texture = modelTexture
        offscreenPassDescriptor.colorAttachments[0].loadAction = .clear
        offscreenPassDescriptor.colorAttachments[0].storeAction = .store
        offscreenPassDescriptor.colorAttachments[0].clearColor = MTLClearColorMake(0, 0, 0, 1)
        offscreenPassDescriptor.depthAttachment.texture = modelDepthTexture
        offscreenPassDescriptor.depthAttachment.loadAction = .clear
        offscreenPassDescriptor.depthAttachment.storeAction = .dontCare
        offscreenPassDescriptor.depthAttachment.clearDepth = 1.0
        encodePass(on: commandBuffer, descriptor: offscreenPassDescriptor) { encoder in
            try drawModel(in: view, renderEncoder: encoder)
        }

        // ---- POST-PROCESS: modelTexture -> postProcessTexture ----
        let postProcessPassDescriptor = MTLRenderPassDescriptor()
        postProcessPassDescriptor.colorAttachments[0].texture = postProcessTexture
        postProcessPassDescriptor.colorAttachments[0].loadAction = .clear
        postProcessPassDescriptor.colorAttachments[0].storeAction = .store
        postProcessPassDescriptor.colorAttachments[0].clearColor = MTLClearColorMake(0, 0, 0, 1)
        encodePass(on: commandBuffer, descriptor: postProcessPassDescriptor) { encoder in
            try drawQuants(in: view, renderEncoder: encoder, texture: modelTexture)
        }

        // ---- SCENE BLIT: postProcessTexture -> drawable ----
        if let scenePassDescriptor = view.currentRenderPassDescriptor {
            encodePass(on: commandBuffer, descriptor: scenePassDescriptor) { encoder in
                try drawBlit(it: view, renderEncoder: encoder, sceneTexture: postProcessTexture)
            }
        }

        commandBuffer.present(drawable)
        commandBuffer.commit()
    }

    /// Nothing is cached per drawable size yet, so there is nothing to rebuild.
    func mtkView(
        _ view: MTKView,
        drawableSizeWillChange size: CGSize
    ) {}

    /// Creates a private-storage 2D texture with no mipmaps to use as a pass attachment.
    private func makeRenderTarget(
        pixelFormat: MTLPixelFormat,
        width: Int,
        height: Int,
        usage: MTLTextureUsage
    ) -> (any MTLTexture)? {
        let descriptor = MTLTextureDescriptor.texture2DDescriptor(
            pixelFormat: pixelFormat,
            width: width,
            height: height,
            mipmapped: false
        )
        descriptor.usage = usage
        descriptor.storageMode = .private
        descriptor.textureType = .type2D
        return device.makeTexture(descriptor: descriptor)
    }

    /// Opens a render encoder for `descriptor`, runs `body`, then ends encoding.
    /// Silently skips the pass if the encoder cannot be created. Traps if `body` throws.
    private func encodePass(
        on commandBuffer: any MTLCommandBuffer,
        descriptor: MTLRenderPassDescriptor,
        _ body: (any MTLRenderCommandEncoder) throws -> Void
    ) {
        guard let encoder = commandBuffer.makeRenderCommandEncoder(descriptor: descriptor) else {
            return
        }
        do {
            try body(encoder)
        } catch {
            fatalError(error.localizedDescription)
        }
        encoder.endEncoding()
    }
}
extension Renderer {
    /// Encodes the instanced, textured, point-lit model into `renderEncoder`.
    ///
    /// Bindings:
    /// - vertex buffer 0: mesh vertices
    /// - vertex buffer 1: `SceneUniforms` (perspective * view)
    /// - vertex buffer 2: one model matrix per instance
    /// - fragment buffer 1: `PointLight`
    /// - fragment texture 0 / sampler 0: base color
    ///
    /// Each submesh is drawn with the texture at the same index in
    /// `mdlObject.textures`. Submeshes beyond the texture count are not drawn.
    ///
    /// - Throws: Errors from shader-function lookup or pipeline creation.
    func drawModel(
        in view: MTKView,
        renderEncoder: any MTLRenderCommandEncoder
    ) throws {
        // No instances means nothing to draw; this also avoids a zero-length bind below.
        guard !instanceTransforms.isEmpty else { return }

        // TODO(perf): pipeline, depth-stencil and sampler states are rebuilt every
        // frame; they depend only on formats and could be created once.
        let pipelineDescriptor = MTLRenderPipelineDescriptor()
        pipelineDescriptor.vertexDescriptor = MTKMetalVertexDescriptorFromModelIO(mdlObject.mdlVertexDescriptor)
        pipelineDescriptor.vertexFunction = try library.modelVertex
        pipelineDescriptor.fragmentFunction = try library.modelFragment
        pipelineDescriptor.colorAttachments[0].pixelFormat = view.colorPixelFormat
        pipelineDescriptor.depthAttachmentPixelFormat = .depth32Float
        let renderPipelineState = try device.makeRenderPipelineState(descriptor: pipelineDescriptor)
        renderEncoder.setRenderPipelineState(renderPipelineState)

        let depthStencilDescriptor = MTLDepthStencilDescriptor()
        depthStencilDescriptor.depthCompareFunction = .less
        depthStencilDescriptor.isDepthWriteEnabled = true
        if let depthStencilState = device.makeDepthStencilState(descriptor: depthStencilDescriptor) {
            renderEncoder.setDepthStencilState(depthStencilState)
        }

        renderEncoder.setVertexBuffer(
            mdlObject.mesh.vertexBuffers[0].buffer,
            offset: mdlObject.mesh.vertexBuffers[0].offset,
            index: 0
        )

        let aspect = Float(view.drawableSize.width / max(1, view.drawableSize.height))
        let perspectiveMatrix = AffineTransform.perspective(
            fovyRadians: .pi / 4,
            aspect: aspect,
            near: 0.1,
            far: 1000
        )
        let viewMatrix = AffineTransform.lookAt(
            eye: SIMD3<Float>(0.0, 0.0, 40.0),
            center: SIMD3<Float>(0.0, 0.0, 0.0),
            up: SIMD3<Float>(0.0, 1.0, 0.0)
        )
        var uniforms: SceneUniforms = .init(
            projection: perspectiveMatrix * viewMatrix
        )
        renderEncoder.setVertexBytes(
            &uniforms,
            length: MemoryLayout<SceneUniforms>.stride,
            index: 1
        )

        // Every instance gets the same spin about +Y for this frame.
        let spin = simd_quatf(
            angle: rotationSpeed * animationTime,
            axis: SIMD3<Float>(0, 1, 0)
        )
        var instanceMatrices = instanceTransforms.map { transform -> float4x4 in
            var animated = transform
            animated.rotation = spin
            return animated.modelMatrix
        }

        // `instanceBuffer` is one shared buffer with no per-frame synchronization.
        // Rewriting it while earlier frames are still executing on the GPU is a data
        // race. For small payloads, `setVertexBytes` copies the data into the command
        // stream instead, so each frame gets its own snapshot. Metal's documented
        // limit for this call is 4 KB.
        let matrixStride = MemoryLayout<float4x4>.stride
        let instanceCount: Int
        if matrixStride * instanceMatrices.count <= 4096 {
            instanceCount = instanceMatrices.count
            renderEncoder.setVertexBytes(
                &instanceMatrices,
                length: matrixStride * instanceCount,
                index: 2
            )
        } else {
            // Fallback for large instance sets. The buffer was sized at init, so never
            // write or draw past its capacity.
            instanceCount = min(instanceMatrices.count, instanceBuffer.length / matrixStride)
            let pointer = instanceBuffer
                .contents()
                .bindMemory(to: float4x4.self, capacity: instanceCount)
            for index in 0..<instanceCount {
                pointer[index] = instanceMatrices[index]
            }
            renderEncoder.setVertexBuffer(instanceBuffer, offset: 0, index: 2)
        }

        var pointLight = PointLight(
            position: SIMD3<Float>(0.0, 0.0, 10.0),
            color: SIMD3<Float>(1.0, 1.0, 1.0),
            intensity: 16.0,
            attenuation: 0.1
        )

        let sampDesc = MTLSamplerDescriptor()
        sampDesc.minFilter = .linear
        sampDesc.magFilter = .linear
        sampDesc.sAddressMode = .repeat
        sampDesc.tAddressMode = .repeat
        let sampler = device.makeSamplerState(descriptor: sampDesc)
        renderEncoder.setFragmentSamplerState(sampler, index: 0)
        renderEncoder.setFragmentBytes(
            &pointLight,
            length: MemoryLayout<PointLight>.stride,
            index: 1
        )

        for (submesh, texture) in zip(mdlObject.submeshes, mdlObject.textures) {
            renderEncoder.setFragmentTexture(texture, index: 0)
            renderEncoder.drawIndexedPrimitives(
                type: submesh.primitiveType,
                indexCount: submesh.indexCount,
                indexType: submesh.indexType,
                indexBuffer: submesh.indexBuffer.buffer,
                indexBufferOffset: submesh.indexBuffer.offset,
                instanceCount: instanceCount
            )
        }
    }

    /// Draws a full-screen triangle that copies `sceneTexture` into the view's pass.
    ///
    /// The depth format matches the view's because `currentRenderPassDescriptor`
    /// carries the view's depth attachment. The `it` label is kept as-is for
    /// source compatibility.
    ///
    /// - Throws: Errors from shader-function lookup or pipeline creation.
    func drawBlit(
        it view: MTKView,
        renderEncoder: any MTLRenderCommandEncoder,
        sceneTexture: any MTLTexture
    ) throws {
        let pipelineDescriptor = MTLRenderPipelineDescriptor()
        pipelineDescriptor.vertexFunction = try library.blitVertex
        pipelineDescriptor.fragmentFunction = try library.blitFragment
        pipelineDescriptor.colorAttachments[0].pixelFormat = view.colorPixelFormat
        pipelineDescriptor.depthAttachmentPixelFormat = view.depthStencilPixelFormat
        let pipelineState = try device.makeRenderPipelineState(descriptor: pipelineDescriptor)
        renderEncoder.setRenderPipelineState(pipelineState)

        let samplerDescriptor = MTLSamplerDescriptor()
        samplerDescriptor.minFilter = .linear
        samplerDescriptor.magFilter = .linear
        samplerDescriptor.sAddressMode = .clampToEdge
        samplerDescriptor.tAddressMode = .clampToEdge
        let sampler = device.makeSamplerState(descriptor: samplerDescriptor)!

        renderEncoder.setFragmentTexture(sceneTexture, index: 0)
        renderEncoder.setFragmentSamplerState(sampler, index: 0)
        renderEncoder.drawPrimitives(type: .triangle, vertexStart: 0, vertexCount: 3)
    }

    /// Runs the pixelation + grid post-process on `texture` as a full-screen triangle.
    ///
    /// - Throws: Errors from shader-function lookup or pipeline creation.
    func drawQuants(
        in view: MTKView,
        renderEncoder: any MTLRenderCommandEncoder,
        texture: any MTLTexture
    ) throws {
        /// Host mirror of the shader's `PostProcessUniforms`. Field order and types
        /// must match: the 16-byte-aligned `SIMD3<Float>` puts `gridColor` at offset 16.
        struct PPUniforms {
            var viewportSize: SIMD2<Float>
            var pixelSize: Float
            var lineThickness: Float
            var gridColor: SIMD3<Float>
            var gridAlpha: Float
        }

        // Nearest filtering keeps each block a single flat color.
        let samplerDescriptor = MTLSamplerDescriptor()
        samplerDescriptor.minFilter = .nearest
        samplerDescriptor.magFilter = .nearest
        samplerDescriptor.sAddressMode = .clampToEdge
        samplerDescriptor.tAddressMode = .clampToEdge
        let sampler = device.makeSamplerState(descriptor: samplerDescriptor)

        let pipelineDescriptor = MTLRenderPipelineDescriptor()
        pipelineDescriptor.vertexFunction = try library.quantVertex
        pipelineDescriptor.fragmentFunction = try library.quantFragment
        pipelineDescriptor.colorAttachments[0].pixelFormat = view.colorPixelFormat
        let renderPipelineState = try device.makeRenderPipelineState(descriptor: pipelineDescriptor)

        let width = max(1, Int(view.drawableSize.width))
        let height = max(1, Int(view.drawableSize.height))
        var ppUniforms = PPUniforms(
            viewportSize: SIMD2<Float>(Float(width), Float(height)),
            pixelSize: 14.0, // size of each pixel block in screen pixels
            lineThickness: 1.0, // grid line thickness in pixels
            gridColor: SIMD3<Float>(0.1, 0.1, 0.1), // dark grid
            gridAlpha: 0.35 // grid opacity
        )

        renderEncoder.setRenderPipelineState(renderPipelineState)
        renderEncoder.setFragmentTexture(texture, index: 0)
        renderEncoder.setFragmentSamplerState(sampler, index: 0)
        renderEncoder.setFragmentBytes(&ppUniforms, length: MemoryLayout<PPUniforms>.stride, index: 0)
        renderEncoder.drawPrimitives(type: .triangle, vertexStart: 0, vertexCount: 3)
    }
}
import MetalKit
/// Thin wrapper over `MTLLibrary` that exposes shader functions as dynamic
/// members: `try library.modelVertex` looks up the function named "modelVertex".
@dynamicMemberLookup
public struct ShaderLibrary {
    /// The wrapped Metal library that functions are looked up in.
    let library: MTLLibrary

    /// Wraps an existing Metal library.
    /// - Parameter library: The compiled library to resolve function names against.
    public init(
        library: MTLLibrary
    ) {
        self.library = library
    }

    /// Looks up a shader function by name.
    ///
    /// Uses the `constantValues:` overload with an empty constant set, which
    /// throws on failure instead of returning `nil`.
    /// - Parameter name: Name of the shader function as declared in the library.
    /// - Returns: The resolved Metal function.
    /// - Throws: Whatever `makeFunction(name:constantValues:)` throws, e.g. when no
    ///   function named `name` exists.
    public func function(named name: String) throws -> MTLFunction {
        let noConstants = MTLFunctionConstantValues()
        return try library.makeFunction(name: name, constantValues: noConstants)
    }

    /// Dynamic-member spelling of `function(named:)`.
    public subscript(
        dynamicMember member: String
    ) -> MTLFunction {
        get throws {
            return try function(named: member)
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment