Created
December 2, 2022 20:04
-
-
Save daym/3eb52563358fcedac4c5fdd2f020f75c to your computer and use it in GitHub Desktop.
Matrices and Vectors
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #![feature(array_zip)] | |
| use core::iter::Sum; | |
| use core::ops::{Add, Sub, Mul, Div}; | |
| use num_traits::{Zero, One}; | |
| // TODO: AddAssign and so on (sigh) | |
| pub trait Field: Add<Output=Self> + Sub<Output=Self> + Mul<Output=Self> + Div<Output=Self> + Sized + Copy + Zero + One + PartialEq + Sum {} | |
| impl Field for f32 {} | |
| impl Field for i32 {} | |
| impl Field for f64 {} | |
| impl Field for i64 {} | |
| // TODO: AddAssign and so on (sigh) | |
| pub trait Vector<F: Field>: Add<Output=Self> + Sub<Output=Self> + Mul<F> + Zero {} | |
| trait Matrix<F: Field, const ROW_COUNT: usize, const COLUMN_COUNT: usize>: Zero + Add + Sub + Sized { // : Mul<Matrix<F, X_COUNT, COLUMN_COUNT>> + Sized | |
| } | |
| // MatrixRc | |
| #[derive(Debug, PartialEq)] | |
| pub struct MatrixRc<F: Field, const ROW_COUNT: usize, const COLUMN_COUNT: usize>([[F; COLUMN_COUNT]; ROW_COUNT]); | |
| impl<F: Field, const ROW_COUNT: usize, const COLUMN_COUNT: usize> Zero for MatrixRc<F, ROW_COUNT, COLUMN_COUNT> { | |
| fn zero() -> Self { | |
| let result = [[F::zero(); COLUMN_COUNT]; ROW_COUNT]; | |
| MatrixRc(result) | |
| } | |
| fn is_zero(&self) -> bool { | |
| let result = [[F::zero(); COLUMN_COUNT]; ROW_COUNT]; | |
| self.0 == result | |
| } | |
| } | |
| impl<F: Field, const ROW_COUNT: usize, const COLUMN_COUNT: usize> Matrix<F, ROW_COUNT, COLUMN_COUNT> for MatrixRc<F, ROW_COUNT, COLUMN_COUNT> {} | |
| impl<F: Field, const ROW_COUNT: usize, const X_COUNT: usize, const COLUMN_COUNT: usize> Mul<MatrixRc<F, X_COUNT, COLUMN_COUNT>> for MatrixRc<F, ROW_COUNT, X_COUNT> { | |
| type Output = MatrixRc<F, ROW_COUNT, COLUMN_COUNT>; | |
| fn mul(self, rhs: MatrixRc<F, X_COUNT, COLUMN_COUNT>) -> Self::Output { | |
| let mut result = [[F::zero(); COLUMN_COUNT]; ROW_COUNT]; | |
| for i in 0..ROW_COUNT { | |
| for k in 0..COLUMN_COUNT { | |
| for j in 0..X_COUNT { | |
| result[i][k] = result[i][k] + self.0[i][j] * rhs.0[j][k] | |
| } | |
| } | |
| } | |
| MatrixRc(result) | |
| } | |
| } | |
| impl<F: Field, const ROW_COUNT: usize, const COLUMN_COUNT: usize> Add for MatrixRc<F, ROW_COUNT, COLUMN_COUNT> { | |
| type Output = Self; | |
| fn add(self, rhs: Self) -> Self::Output { | |
| let mut result = [[F::zero(); COLUMN_COUNT]; ROW_COUNT]; | |
| for i in 0..ROW_COUNT { | |
| for j in 0..COLUMN_COUNT { | |
| result[i][j] = self.0[i][j] + rhs.0[i][j] | |
| } | |
| } | |
| Self(result) | |
| } | |
| } | |
| impl<F: Field, const ROW_COUNT: usize, const COLUMN_COUNT: usize> Sub for MatrixRc<F, ROW_COUNT, COLUMN_COUNT> { | |
| type Output = Self; | |
| fn sub(self, rhs: Self) -> Self::Output { | |
| let mut result = [[F::zero(); COLUMN_COUNT]; ROW_COUNT]; | |
| for i in 0..ROW_COUNT { | |
| for j in 0..COLUMN_COUNT { | |
| result[i][j] = self.0[i][j] - rhs.0[i][j] | |
| } | |
| } | |
| Self(result) | |
| } | |
| } | |
| // MatrixCr | |
| // TODO: Maybe impl Debug so that it's still row, then column | |
| // TODO: maybe a reference instead; or a Cow | |
| #[derive(Debug, PartialEq)] | |
| pub struct MatrixCr<F: Field, const ROW_COUNT: usize, const COLUMN_COUNT: usize>([[F; ROW_COUNT]; COLUMN_COUNT]); | |
| impl<F: Field, const ROW_COUNT: usize, const COLUMN_COUNT: usize> Add for MatrixCr<F, ROW_COUNT, COLUMN_COUNT> { | |
| type Output = Self; | |
| fn add(self, rhs: Self) -> Self::Output { | |
| let mut result = [[F::zero(); ROW_COUNT]; COLUMN_COUNT]; | |
| for i in 0..COLUMN_COUNT { | |
| for j in 0..ROW_COUNT { | |
| result[i][j] = self.0[i][j] + rhs.0[i][j] | |
| } | |
| } | |
| Self(result) | |
| } | |
| } | |
| impl<F: Field, const ROW_COUNT: usize, const COLUMN_COUNT: usize> Sub for MatrixCr<F, ROW_COUNT, COLUMN_COUNT> { | |
| type Output = Self; | |
| fn sub(self, rhs: Self) -> Self::Output { | |
| let mut result = [[F::zero(); ROW_COUNT]; COLUMN_COUNT]; | |
| for i in 0..COLUMN_COUNT { | |
| for j in 0..ROW_COUNT { | |
| result[i][j] = self.0[i][j] - rhs.0[i][j] | |
| } | |
| } | |
| Self(result) | |
| } | |
| } | |
| impl<F: Field, const ROW_COUNT: usize, const COLUMN_COUNT: usize> Zero for MatrixCr<F, ROW_COUNT, COLUMN_COUNT> { | |
| fn zero() -> Self { | |
| let result = [[F::zero(); ROW_COUNT]; COLUMN_COUNT]; | |
| Self(result) | |
| } | |
| fn is_zero(&self) -> bool { | |
| let result = [[F::zero(); ROW_COUNT]; COLUMN_COUNT]; | |
| self.0 == result | |
| } | |
| } | |
| pub type ColumnVector<F, const DIM: usize> = MatrixCr<F, DIM, 1>; | |
| impl<F: Field, const ROW_COUNT: usize> ColumnVector<F, ROW_COUNT> { | |
| pub fn obi(&self, index: usize) -> F { | |
| assert!(index > 0); | |
| self.0[0][index - 1] | |
| } | |
| pub fn new_from_coordinates(coordinates: [F; ROW_COUNT]) -> Self { | |
| Self([coordinates]) | |
| } | |
| } | |
| impl<F: Field> ColumnVector<F, 3> { | |
| pub fn cross(&self, b: &Self) -> Self { | |
| let a = self; | |
| Self::new_from_coordinates([a.obi(2) * b.obi(3) - a.obi(3) * b.obi(2), a.obi(1) * b.obi(3) - a.obi(3) * b.obi(1), a.obi(1) * b.obi(2) - a.obi(2) * b.obi(1)]) | |
| } | |
| } | |
| impl<F: Field, const ROW_COUNT: usize, const COLUMN_COUNT: usize> Mul<F> for MatrixRc<F, ROW_COUNT, COLUMN_COUNT> { | |
| type Output = Self; | |
| fn mul(self, rhs: F) -> Self { | |
| Self(self.0.map(|a| a.map(|aa| aa * rhs))) | |
| } | |
| } | |
| impl<F: Field, const ROW_COUNT: usize, const COLUMN_COUNT: usize> Mul<F> for MatrixCr<F, ROW_COUNT, COLUMN_COUNT> { | |
| type Output = Self; | |
| fn mul(self, rhs: F) -> Self { | |
| Self(self.0.map(|a| a.map(|aa| aa * rhs))) | |
| } | |
| } | |
| impl<F: Field, const DIM: usize> Vector<F> for ColumnVector<F, DIM> {} | |
| impl<F: Field, const DIM: usize> ColumnVector<F, DIM> { | |
| pub fn dot(&self, rhs: &Self) -> F { | |
| self.0[0].zip(rhs.0[0]).map(|(a,b)| a * b).into_iter().sum() | |
| } | |
| } | |
| // FIXME: Or impl From | |
| // impl<F: Field, const ROW_COUNT: usize> Matrix<F, ROW_COUNT, 1> for ColumnVector<F, ROW_COUNT> { | |
| // } | |
| // FIXME: or ColumnVector = MatrixCr<F, ROW_COUNT, 1> | |
| impl<F: Field, const ROW_COUNT: usize, const COLUMN_COUNT: usize> MatrixRc<F, ROW_COUNT, COLUMN_COUNT> { | |
| /// Column 1-based index | |
| pub fn cobi(&self, index: usize) -> ColumnVector<F, ROW_COUNT> { | |
| assert!(index > 0); | |
| let index = index - 1; | |
| let mut result: [F; ROW_COUNT] = [F::zero(); ROW_COUNT]; | |
| for i in 0..ROW_COUNT { | |
| result[i] = self.0[i][index] | |
| } | |
| ColumnVector::new_from_coordinates(result) | |
| } | |
| /// Construct a matrix from the given column vectors | |
| pub fn new(columns: [ColumnVector<F, ROW_COUNT>; COLUMN_COUNT]) -> Self { | |
| let mut result: [[F; COLUMN_COUNT]; ROW_COUNT] = [[F::zero(); COLUMN_COUNT]; ROW_COUNT]; | |
| for i in 0..ROW_COUNT { | |
| for j in 0..COLUMN_COUNT { | |
| result[i][j] = columns[j].0[0][i] | |
| } | |
| } | |
| Self(result) | |
| } | |
| } | |
| impl<F: Field, const ROW_COUNT: usize, const COLUMN_COUNT: usize> MatrixCr<F, ROW_COUNT, COLUMN_COUNT> { | |
| /// Column 1-based index | |
| pub fn cobi(&self, index: usize) -> ColumnVector<F, ROW_COUNT> { | |
| assert!(index > 0); | |
| ColumnVector::new_from_coordinates(self.0[index - 1]) | |
| } | |
| /// Construct a matrix from the given column vectors | |
| pub fn new(columns: [ColumnVector<F, ROW_COUNT>; COLUMN_COUNT]) -> Self { | |
| Self(columns.map(|x| x.0[0])) | |
| } | |
| } | |
| fn main() { | |
| let v = ColumnVector::new_from_coordinates( | |
| [1.0, | |
| 3.0, | |
| 4.0]); | |
| let w = ColumnVector::new_from_coordinates([1.0, | |
| 0.0, | |
| 4.0]); | |
| assert_eq!(v + w, ColumnVector::new_from_coordinates([2.0, | |
| 3.0, | |
| 8.0])); | |
| let v = ColumnVector::new_from_coordinates([1.0, | |
| 3.0, | |
| 4.0]); | |
| assert_eq!(v.cross(&v), ColumnVector::zero()); | |
| let v = ColumnVector::new_from_coordinates([1.0, | |
| 3.0, | |
| 4.0]); | |
| println!("v = {:?}", v); | |
| assert_eq!(v.dot(&v), 26.0); | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment