Last active
November 17, 2025 16:38
-
-
Save timosarkar/fe2c9a15d409e90bf2f78d85400e6fbd to your computer and use it in GitHub Desktop.
Tensor implementation in Python & Zig
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Tensor:
    """A minimal n-dimensional tensor backed by nested Python lists.

    ``data`` is either a non-list scalar or a nested list of scalars.
    ``shape`` is a tuple with one entry per nesting level, e.g.
    ``[[1, 2, 3], [4, 5, 6]]`` has shape ``(2, 3)`` and a bare scalar has
    shape ``()``.
    """

    def __init__(self, data):
        self.data = data
        self.shape = self.get_shape(data)

    def get_shape(self, data):
        """Return the shape tuple of ``data``.

        Non-list values are scalars with shape ``()``; an empty list has
        shape ``(0,)``.

        Raises:
            ValueError: if ``data`` is ragged, i.e. sibling sub-lists do not
                all share the same shape.
        """
        if not isinstance(data, list):
            return ()  # scalar: zero dimensions
        if not data:
            return (0,)  # empty list has no data[0] to recurse into
        inner = self.get_shape(data[0])
        # Every sibling must match the first one; otherwise the zip() in the
        # element-wise helpers would silently truncate mismatched rows.
        for item in data[1:]:
            if self.get_shape(item) != inner:
                raise ValueError("Ragged nested list: sub-lists have different shapes")
        return (len(data),) + inner

    def __add__(self, other):
        """Return the element-wise sum of two tensors with identical shape.

        Returns ``NotImplemented`` for non-Tensor operands so Python raises
        the usual ``TypeError``.

        Raises:
            ValueError: if the shapes differ.
        """
        if not isinstance(other, Tensor):
            return NotImplemented
        if self.shape != other.shape:  # check if self and other have same shape
            raise ValueError("Shapes do not match for addition")
        return Tensor(self.addr(self.data, other.data))

    def addr(self, data1, data2):
        """Recursively add two equally shaped nested lists (or scalars)."""
        if isinstance(data1, list):  # recurse one nesting level
            return [self.addr(d1, d2) for d1, d2 in zip(data1, data2)]
        return data1 + data2

    def __mul__(self, other):
        """Return the element-wise (Hadamard) product of two tensors.

        Returns ``NotImplemented`` for non-Tensor operands so Python raises
        the usual ``TypeError``.

        Raises:
            ValueError: if the shapes differ.
        """
        if not isinstance(other, Tensor):
            return NotImplemented
        if self.shape != other.shape:  # check if self and other have same shape
            raise ValueError("Shapes do not match for multiplication")
        return Tensor(self.mulr(self.data, other.data))

    def mulr(self, data1, data2):
        """Recursively multiply two equally shaped nested lists (or scalars)."""
        if isinstance(data1, list):  # recurse one nesting level
            return [self.mulr(d1, d2) for d1, d2 in zip(data1, data2)]
        return data1 * data2

    def __repr__(self):
        return f"Tensor(shape={self.shape}, data={self.data})"
# Demo 1: element-wise ops on two 2x3 tensors.
t1 = Tensor([[1, 2, 3], [4, 5, 6]])
t2 = Tensor([[7, 8, 9], [10, 11, 12]])
sum_2d = t1 + t2
print("Addition Result:", sum_2d)
product_2d = t1 * t2
print("Multiplication Result:", product_2d)

# Demo 2: a 3-D (2, 2, 3) case where every innermost row is identical.
t5 = Tensor([[[1, 2, 3] for _ in range(2)] for _ in range(2)])
t6 = Tensor([[[4, 5, 6] for _ in range(2)] for _ in range(2)])
print(t5 + t6)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| // zig build-exe tensor.zig | |
| const std = @import("std"); | |
/// A tensor stored as a flat `f64` buffer plus a shape (one extent per
/// dimension). Neither slice is copied: a `Tensor` borrows whatever memory
/// its creator supplies. No operation here interprets element order, so the
/// element-wise results do not depend on the layout convention.
pub const Tensor = struct {
    data: []const f64,
    shape: []const usize,

    /// Errors returned by the element-wise operations `add` and `mul`.
    pub const Error = error{ ShapeMismatch, OutOfMemory };

    /// Element-wise operations implemented by `elementwise`.
    const Op = enum { add, mul };

    /// Wraps `data` with `shape` without copying either slice.
    /// Panics if the product of the extents in `shape` differs from
    /// `data.len` (the multiplication itself traps on overflow in safe build
    /// modes). An empty `shape` therefore requires exactly one element.
    pub fn init(data: []const f64, shape: []const usize) Tensor {
        var total: usize = 1;
        for (shape) |dim| total *= dim;
        if (total != data.len)
            @panic("Shape does not match data length");
        return .{ .data = data, .shape = shape };
    }

    /// Frees the `data` buffer of a tensor returned by `add` or `mul`.
    /// `allocator` must be the one passed to that call. `shape` is borrowed
    /// from the operand and is NOT freed. Do not call this on a tensor built
    /// with `init` over memory you did not allocate with `allocator`.
    pub fn deinit(self: Tensor, allocator: std.mem.Allocator) void {
        allocator.free(self.data);
    }

    /// Returns the element-wise sum `self + other` in a newly allocated buffer.
    /// The caller owns the result's `data` (release with `deinit`); the
    /// result's `shape` aliases `self.shape`.
    /// Errors: `ShapeMismatch` if shapes or data lengths differ,
    /// `OutOfMemory` if allocation fails.
    pub fn add(self: Tensor, other: Tensor, allocator: std.mem.Allocator) Error!Tensor {
        return self.elementwise(other, allocator, .add);
    }

    /// Returns the element-wise product `self * other` in a newly allocated
    /// buffer. Ownership and errors are the same as for `add`.
    pub fn mul(self: Tensor, other: Tensor, allocator: std.mem.Allocator) Error!Tensor {
        return self.elementwise(other, allocator, .mul);
    }

    /// Shared implementation of `add`/`mul`: validates the operands,
    /// allocates the output, and applies `op` pairwise.
    fn elementwise(self: Tensor, other: Tensor, allocator: std.mem.Allocator, comptime op: Op) Error!Tensor {
        // Also compare data lengths: the fields are public, so a struct
        // literal can bypass `init`'s check, and a multi-object `for` over
        // unequal lengths is a safety panic (UB in ReleaseFast).
        if (!std.mem.eql(usize, self.shape, other.shape) or self.data.len != other.data.len)
            return error.ShapeMismatch;
        const out = try allocator.alloc(f64, self.data.len);
        for (self.data, other.data, out) |a, b, *dst| {
            dst.* = switch (op) {
                .add => a + b,
                .mul => a * b,
            };
        }
        return Tensor.init(out, self.shape);
    }

    /// Writes the shape and raw data to stderr via `std.debug.print`.
    pub fn print(self: Tensor) void {
        std.debug.print("Tensor(shape={any}, data={any})\n", .{ self.shape, self.data });
    }
};
/// Demo: builds two 2x3 tensors, prints their element-wise sum and product,
/// and frees the result buffers before returning.
pub fn main() !void {
    const allocator = std.heap.page_allocator;

    // Pointers to array literals coerce directly to slices; no `[0..]` needed.
    const t1 = Tensor.init(&[_]f64{ 1, 2, 3, 4, 5, 6 }, &[_]usize{ 2, 3 });
    const t2 = Tensor.init(&[_]f64{ 7, 8, 9, 10, 11, 12 }, &[_]usize{ 2, 3 });

    // `add`/`mul` allocate the result's data with `allocator`; free it here.
    const t3 = try t1.add(t2, allocator);
    defer allocator.free(t3.data);
    const t4 = try t1.mul(t2, allocator);
    defer allocator.free(t4.data);

    t3.print();
    t4.print();
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment