(* Source: GitHub gist t0yv0/9bd27c5d3d58c90f547b8dcd1227d417,
   created September 10, 2025. *)
(*
Construct a tensor product.
Given:
- a vector space X
- and a vector space Y
Construct:
- a vector space Z
- with pack : (X, Y) -> Z
- such that pack is a bilinear map
- and for every bilinear map m : (X, Y) -> U there exists a linear map m' : Z -> U such that m(x, y) = m'(pack(x, y))
Note: the standard universal property also requires m' to be unique; this development constructs m' but does not prove uniqueness.
Ref: https://math.berkeley.edu/~serganov/math252/tens.pdf
*)
(** A field structure on [F] with decidable equality. Identity, inverse and
    distributivity laws are stated on one side only; the other side follows
    from [add_comm] / [mul_comm]. *)
Class Field (F : Type) := {
  (* Operations *)
  zero : F;
  one : F;
  add : F -> F -> F;
  mul : F -> F -> F;
  neg : F -> F;
  inv : F -> F;

  (* Equality on F is decidable. *)
  eq_dec : forall p q : F, {p = q} + {p <> q};

  (* (F, add, zero, neg) is an abelian group. *)
  add_assoc : forall p q r, add (add p q) r = add p (add q r);
  add_comm : forall p q, add p q = add q p;
  add_zero : forall p, add p zero = p;
  add_neg : forall p, add p (neg p) = zero;

  (* mul is associative and commutative with unit one; every non-zero
     element has a multiplicative inverse. *)
  mul_assoc : forall p q r, mul (mul p q) r = mul p (mul q r);
  mul_comm : forall p q, mul p q = mul q p;
  mul_one : forall p, mul p one = p;
  mul_inv : forall p, p <> zero -> mul p (inv p) = one;

  (* mul distributes over add. *)
  mul_add_distr : forall p q r, mul p (add q r) = add (mul p q) (mul p r);

  (* The field has at least two elements. *)
  zero_neq_one : zero <> one
}.
(** A vector space [V] over the field [F]: an abelian group
    (vadd, vzero, vneg) together with a scalar action [smul] compatible
    with the field operations. *)
Class VectorSpace (F : Type) `{Field F} (V : Type) := {
  (* Operations *)
  vadd : V -> V -> V;
  smul : F -> V -> V;
  vzero : V;
  vneg : V -> V;

  (* (V, vadd, vzero, vneg) is an abelian group. *)
  vadd_assoc : forall a b c, vadd a (vadd b c) = vadd (vadd a b) c;
  vadd_comm : forall a b, vadd a b = vadd b a;
  vadd_zero_l : forall a, vadd vzero a = a;
  vadd_neg_l : forall a, vadd (vneg a) a = vzero;

  (* smul distributes over vadd and over the field's add, is compatible
     with the field's mul, and has one as identity. *)
  smul_distrib_l : forall (s : F) (a b : V),
    smul s (vadd a b) = vadd (smul s a) (smul s b);
  smul_distrib_r : forall (s t : F) (a : V),
    smul (add s t) a = vadd (smul s a) (smul t a);
  smul_assoc : forall (s t : F) (a : V),
    smul (mul s t) a = smul s (smul t a);
  smul_one : forall a, smul one a = a
}.
Section LinMaps.
Context {F : Type} `{Field F}.
Context {U V : Type} `{VectorSpace F U} `{VectorSpace F V}.

(** A linear map from [U] to [V]: the underlying function bundled with
    proofs that it preserves vector addition and scalar multiplication. *)
Record LinMap : Type :=
  mkLinMap
    { linmap : U -> V
    ; linmap_additive : forall a b : U,
        linmap (vadd a b) = vadd (linmap a) (linmap b)
    ; linmap_homogen : forall (k : F) (a : U),
        linmap (smul k a) = smul k (linmap a)
    }.
End LinMaps.
Section TensorProduct.
Context {F : Type} `{Field F}.
Context {X Y : Type} `{VectorSpace F X} `{VectorSpace F Y}.

(** A bilinear map [blmap : X -> Y -> U], stored together with its two
    families of partial applications as linear maps:
    - [linmap1 b] agrees with [fun a => blmap a b] (linear in the X slot);
    - [linmap2 a] agrees with [fun b => blmap a b] (linear in the Y slot). *)
Record BilinearMap (U : Type) `{VectorSpace F U} : Type :=
  mkBiLinMap
    { blmap : X -> Y -> U
    ; linmap1 : Y -> LinMap (U := X) (V := U)
    ; linmap2 : X -> LinMap (U := Y) (V := U)
    ; linmap1_correct : forall (a : X) (b : Y), blmap a b = linmap (linmap1 b) a
    ; linmap2_correct : forall (a : X) (b : Y), blmap a b = linmap (linmap2 a) b
    }.
(* Z, the tensor product of X and Y, is represented by samplers: an element
   of Z is a function that, given any target type U with a vector-space
   structure and any bilinear map bm : (X, Y) -> U, produces a value in U.
   For example, the element pack x y samples bm at (x, y). *)
Inductive Z : Type :=
| make_z (sample : forall {U: Type} `{VectorSpace F U}, BilinearMap U -> U) : Z.
(* Run the sampler z on the bilinear map bm.
   NOTE(review): the two underscores fill the sampler's instance arguments,
   presumably the instances introduced by the generalizing binder in
   make_z's type -- confirm with Print Z. *)
Definition unz {U: Type} `{VectorSpace F U} (z: Z) (bm: BilinearMap U) : U :=
  let (sample) := z in sample U _ _ bm.
(* Extensionality for Z (an unproved assumption): two samplers are equal
   when they give the same sample on every bilinear map.
   NOTE(review): the (U: Type) binder is never used in the statement, and
   bm carries no type annotation, so the target space of bm is NOT bound by
   U; it is left to implicit-argument / instance inference. As written, the
   axiom therefore presumably only quantifies over bilinear maps into one
   inferred space, not into every vector space as intended -- confirm with
   Print zeq. Because U is unused, each apply zeq below presumably leaves a
   goal of type Type, which the auto that follows discharges. Restating the
   axiom over all U must be coordinated with the instance binders in
   make_z's type, or the z_smul_* proofs may stop going through. *)
Axiom zeq : forall (U: Type) (z1 z2 : Z),
(forall bm, unz z1 bm = unz z2 bm) -> z1 = z2.
(* Vector addition on Z, computed pointwise: the sum of two samplers
   samples bm by adding the two individual samples in U. *)
Definition zadd (z1 z2 : Z) : Z :=
  make_z (fun U fld vsp bm =>
            vadd (unz z1 bm) (unz z2 bm)).
(* zadd is associative: after zeq reduces the claim to one sample,
   it is exactly vadd_assoc in the target space. *)
Lemma zadd_assoc : forall u v w, zadd u (zadd v w) = zadd (zadd u v) w.
Proof.
  intros u v w.
  apply zeq; auto; intros.
  simpl.
  apply vadd_assoc.
Qed.
(* zadd is commutative: pointwise, this is vadd_comm in the target space. *)
Lemma zadd_comm : forall u v, zadd u v = zadd v u.
Proof.
  intros u v.
  apply zeq; auto; intros.
  unfold zadd; simpl.
  apply vadd_comm.
Qed.
(* The zero of Z: the sampler that returns vzero for every bilinear map. *)
Definition zzero : Z :=
  make_z (fun U fld vsp bm => vzero).
(* Negation on Z, computed pointwise: negate each sample of z. *)
Definition zneg (z : Z) : Z :=
  make_z (fun U fld vsp bm => vneg (unz z bm)).
(* zzero is a left identity for zadd, pointwise by vadd_zero_l. *)
Lemma zadd_zero_l : forall z, zadd zzero z = z.
Proof.
  intros z.
  destruct z as [sample].
  apply zeq; auto; intros.
  apply vadd_zero_l.
Qed.
(* zneg z is a left inverse of z under zadd, pointwise by vadd_neg_l. *)
Lemma zadd_neg_l : forall z, zadd (zneg z) z = zzero.
Proof.
  intros z.
  destruct z as [sample].
  apply zeq; auto; intros.
  unfold zneg.
  apply vadd_neg_l.
Qed.
(* Scalar multiplication on Z, computed pointwise: scale each sample of z by f. *)
Definition zsmul (f: F) (z: Z): Z :=
  make_z (fun U fld vsp bm => smul f (unz z bm)).
(* zsmul distributes over zadd, pointwise by smul_distrib_l. *)
Lemma z_smul_distrib_l : forall (k : F) (u v : Z),
  zsmul k (zadd u v) = zadd (zsmul k u) (zsmul k v).
Proof.
  intros k u v.
  unfold zsmul, zadd.
  apply zeq; auto; intros.
  simpl.
  apply smul_distrib_l.
Qed.
(* Scaling by a sum of scalars is the zadd of the two scalings,
   pointwise by smul_distrib_r. *)
Lemma z_smul_distrib_r : forall (k l : F) (v : Z),
  zsmul (add k l) v = zadd (zsmul k v) (zsmul l v).
Proof.
  intros k l v.
  unfold zsmul, zadd.
  destruct v as [sample].
  apply zeq; auto; intros.
  simpl.
  apply smul_distrib_r.
Qed.
(* Scaling by a product equals successive scalings, pointwise by smul_assoc. *)
Lemma z_smul_assoc : forall (k l : F) (v : Z),
  zsmul (mul k l) v = zsmul k (zsmul l v).
Proof.
  intros k l v.
  unfold zsmul.
  destruct v as [sample].
  apply zeq; auto; intros.
  simpl.
  apply smul_assoc.
Qed.
(* Scaling by one is the identity on Z, pointwise by smul_one. *)
Lemma z_smul_one : forall (z : Z), zsmul one z = z.
Proof.
  intros z.
  unfold zsmul.
  destruct z as [sample].
  apply zeq; auto; intros.
  apply smul_one.
Qed.
(* Z with the pointwise operations zadd / zsmul / zzero / zneg is a vector
   space over F; each axiom is supplied by the matching lemma proved above. *)
Instance Z_VectorSpace : VectorSpace F Z :=
  {| vadd := zadd
   ; smul := zsmul
   ; vzero := zzero
   ; vneg := zneg
   ; vadd_assoc := zadd_assoc
   ; vadd_comm := zadd_comm
   ; vadd_zero_l := zadd_zero_l
   ; vadd_neg_l := zadd_neg_l
   ; smul_distrib_l := z_smul_distrib_l
   ; smul_distrib_r := z_smul_distrib_r
   ; smul_assoc := z_smul_assoc
   ; smul_one := z_smul_one
   |}.
(* The pure tensor of x and y: the sampler that evaluates every bilinear
   map at the pair (x, y). *)
Definition pack (x : X) (y : Y) : Z :=
  make_z (fun U fld vsp bm => blmap _ bm x y).
(* pack with its X argument fixed, viewed as a function of y. *)
Definition pack1 (x: X) : Y -> Z := fun y => pack x y.
(* pack with its Y argument fixed, viewed as a function of x. *)
Definition pack2 (y: Y) (x: X) : Z := pack x y.
(* pack1 x is additive in y. The proof reduces to a pointwise equation via
   zeq. It then rewrites blmap into the linear map linmap2 bm x and uses
   that map's additivity (linmap_additive). Finally it rewrites both
   summands back into blmap form. *)
Lemma pack1_additive : forall (x: X) (y1 y2 : Y), pack1 x (vadd y1 y2) = vadd (pack1 x y1) (pack1 x y2).
intros.
unfold pack1, pack.
simpl.
unfold zadd.
apply zeq.
(* presumably closes the Type goal left by zeq's unused U argument *)
auto.
intros.
simpl.
(* blmap bm x (vadd y1 y2) becomes linmap (linmap2 bm x) (vadd y1 y2) *)
rewrite linmap2_correct.
rewrite linmap_additive.
(* fold each summand back: once for y1, once for y2 *)
rewrite <- linmap2_correct.
rewrite <- linmap2_correct.
auto.
Qed.
(* pack1 x is homogeneous in y. After zeq reduces the claim to a pointwise
   equation, blmap is rewritten into the linear map linmap2 bm x. The proof
   then uses linmap_homogen and rewrites the result back into blmap form. *)
Lemma pack1_homogen : forall (x: X) (c: F) (y: Y), pack1 x (smul c y) = smul c (pack1 x y).
intros.
unfold pack1.
unfold pack.
simpl.
unfold zsmul.
simpl.
apply zeq.
(* presumably closes the Type goal left by zeq's unused U argument *)
auto.
intros.
simpl.
rewrite linmap2_correct.
rewrite linmap_homogen.
rewrite <- linmap2_correct.
auto.
Qed.
(* pack1 x packaged as a linear map Y -> Z, using the additivity and
   homogeneity lemmas just proved. *)
Definition pack1_lm (x: X) : LinMap (U := Y) (V := Z) :=
  {| linmap := pack1 x
   ; linmap_additive := pack1_additive x
   ; linmap_homogen := pack1_homogen x
   |}.
(* pack2 y is additive in x. This mirrors pack1_additive, but goes through
   the other family of linear maps, linmap1 bm y. *)
Lemma pack2_additive : forall (y : Y) (x1 x2: X) , pack2 y (vadd x1 x2) = vadd (pack2 y x1) (pack2 y x2).
intros.
unfold pack2.
unfold pack.
simpl.
unfold zadd.
apply zeq.
(* presumably closes the Type goal left by zeq's unused U argument *)
auto.
intros.
simpl.
(* blmap bm (vadd x1 x2) y becomes linmap (linmap1 bm y) (vadd x1 x2) *)
rewrite linmap1_correct.
rewrite linmap_additive.
(* fold each summand back: once for x1, once for x2 *)
rewrite <- linmap1_correct.
rewrite <- linmap1_correct.
auto.
Qed.
(* pack2 y is homogeneous in x. This mirrors pack1_homogen, but goes
   through linmap1 bm y. *)
Lemma pack2_homogen : forall (y: Y) (c: F) (x: X), pack2 y (smul c x) = smul c (pack2 y x).
intros.
unfold pack2.
unfold pack.
simpl.
unfold zsmul.
simpl.
apply zeq.
(* presumably closes the Type goal left by zeq's unused U argument *)
auto.
intros.
simpl.
rewrite linmap1_correct.
rewrite linmap_homogen.
rewrite <- linmap1_correct.
auto.
Qed.
(* pack2 y packaged as a linear map X -> Z, using the additivity and
   homogeneity lemmas just proved. *)
Definition pack2_lm (y: Y) : LinMap (U := X) (V := Z) :=
  {| linmap := pack2 y
   ; linmap_additive := pack2_additive y
   ; linmap_homogen := pack2_homogen y
   |}.
(* pack2_lm y evaluated at x is pack x y; both sides unfold to the same term. *)
Lemma pack2_lm_correct : forall (x: X) (y: Y), pack x y = linmap (pack2_lm y) x.
Proof.
  intros x y.
  reflexivity.
Qed.
(* pack1_lm x evaluated at y is pack x y; both sides unfold to the same term. *)
Lemma pack1_lm_correct : forall (x: X) (y: Y), pack x y = linmap (pack1_lm x) y.
Proof.
  intros x y.
  reflexivity.
Qed.
(* pack is a bilinear map X -> Y -> Z. Its partial applications are the
   linear maps pack2_lm y (linear in x) and pack1_lm x (linear in y). *)
Definition pack_bm : BilinearMap Z :=
  {| blmap := pack
   ; linmap1 := pack2_lm
   ; linmap2 := pack1_lm
   ; linmap1_correct := pack2_lm_correct
   ; linmap2_correct := pack1_lm_correct
   |}.
(* Evaluate z on a fixed bilinear map bm. This computes the same thing as
   unz with its arguments swapped, so that zmap bm : Z -> U can be partially
   applied (see zmap_unz). *)
Definition zmap {U: Type} `{VectorSpace F U} (bm: BilinearMap U) (z: Z) : U :=
  let (sample) := z in sample U _ _ bm.
(* zmap and unz agree; both unfold to the same match on z. *)
Lemma zmap_unz : forall {U: Type} `{VectorSpace F U} (bm: BilinearMap U) (z: Z),
  zmap bm z = unz z bm.
Proof.
  intros.
  reflexivity.
Qed.
(* zmap bm is additive. Sampling a zadd yields the vadd of the two samples
   by definition, so both sides are convertible. *)
Lemma zmap_additive : forall {U: Type} `{VectorSpace F U} (bm: BilinearMap U) (z1 z2 : Z),
  zmap bm (vadd z1 z2) = vadd (zmap bm z1) (zmap bm z2).
Proof.
  intros.
  reflexivity.
Qed.
(* zmap bm is homogeneous. Sampling a zsmul yields the smul of the sample
   by definition, so both sides are convertible. *)
Lemma zmap_homogen : forall {U: Type} `{VectorSpace F U} (bm: BilinearMap U) (c: F) (z: Z),
  zmap bm (smul c z) = smul c (zmap bm z).
Proof.
  intros.
  reflexivity.
Qed.
(* The linear map Z -> U induced by a bilinear map bm: evaluate each
   sampler on bm. *)
Definition trmap {U: Type} `{VectorSpace F U} (bm: BilinearMap U) : LinMap (F := F) (U := Z) (V := U) :=
  {| linmap := zmap bm
   ; linmap_additive := zmap_additive bm
   ; linmap_homogen := zmap_homogen bm
   |}.
(* Factorization (existence part of the universal property): every bilinear
   map bm equals trmap bm composed with pack. This holds by computation,
   since trmap bm (pack x y) unfolds to blmap bm x y. *)
Lemma trmap_correct : forall {U: Type} `{VectorSpace F U} (bm: BilinearMap U) (x: X) (y: Y),
  blmap _ bm x y = linmap (trmap bm) (pack x y).
Proof.
  intros.
  reflexivity.
Qed.
End TensorProduct.
(* Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment *)