Last active
May 3, 2023 02:26
-
-
Save knikolla/604797eb2972f0f7c2692ded076a3be8 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "# Introduction\n", | |
| "We are trying to model a tree. Each node of the tree inherits the attributes of the parent node.\n", | |
| "All levels of the tree should support the same abstractions and be treated the same.\n", | |
| "\n", | |
| "Ex.\n", | |
| "- Node1 (Institution = 'Boston University')\n", | |
| "- Node2, child of Node1 (Institution = 'Boston University', Field of Science = 'Computer Science')\n", | |
| "- Node3, child of Node2 (Institution = 'Boston University', Field of Science = 'Computer Science', PI = '[email protected]')\n", | |
| "\n", | |
| "Each record is therefore the set of attributes collected along the path from the root to that node.\n", | |
| "\n", | |
| "We can model this in Python with SQLAlchemy by storing a reference to the parent on each node and implementing methods that walk up the tree." | |
| ], | |
| "metadata": { | |
| "collapsed": false | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 487, | |
| "outputs": [], | |
| "source": [ | |
| "import datetime\n", | |
| "import uuid\n", | |
| "\n", | |
| "from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship, Session\n", | |
| "from sqlalchemy import create_engine, Column, DateTime, ForeignKey, Integer, String, select\n", | |
| "from sqlalchemy.sql import func\n", | |
| "\n", | |
| "\n", | |
| "def default_id():\n", | |
| "    \"\"\"Return a fresh random node ID: a UUID4 rendered as 32 lowercase hex digits.\"\"\"\n", | |
| "    fresh = uuid.uuid4()\n", | |
| "    return fresh.hex\n", | |
| "\n", | |
| "\n", | |
| "# Declarative base shared by the models in this notebook; SQLAlchemy records\n", | |
| "# the table of every subclass in Base.metadata.\n", | |
| "class Base(DeclarativeBase):\n", | |
| "    \"\"\"Root of the declarative class hierarchy; declares no columns itself.\"\"\"\n", | |
| "\n", | |
| "class HierarchyNode(Base):\n", | |
| "    \"\"\"A versioned node of an attribute-inheriting tree.\n", | |
| "\n", | |
| "    A logical node is identified by node_id. Changing a node (see set) inserts\n", | |
| "    a new row with the same node_id, and lookups return the matching row with\n", | |
| "    the highest id, i.e. the most recently inserted version.\n", | |
| "    \"\"\"\n", | |
| "    __tablename__ = \"hierarchy_nodes\"\n", | |
| "\n", | |
| "    # Row key of one version; get_node assumes it grows with every insert.\n", | |
| "    id: Mapped[int] = Column(Integer, primary_key=True)\n", | |
| "    # Logical identity shared by all versions of the same node.\n", | |
| "    node_id: Mapped[str] = Column(String, default=default_id)\n", | |
| "    # node_id of the parent (None for a root), typed String to match the\n", | |
| "    # referenced node_id column. NOTE(review): node_id is not unique, so a\n", | |
| "    # database that enforces foreign keys may reject this constraint.\n", | |
| "    parent_id = Column(String, ForeignKey('hierarchy_nodes.node_id'), nullable=True)\n", | |
| "    # func.now() is a SQL expression evaluated at insert time, not once here.\n", | |
| "    created_at: Mapped[datetime.datetime] = Column(DateTime, default=func.now())\n", | |
| "    node_type: Mapped[str]\n", | |
| "    value: Mapped[str]\n", | |
| "    display_value: Mapped[str]\n", | |
| "    status: Mapped[str] = Column(String, default='Active')\n", | |
| "\n", | |
| "    @classmethod\n", | |
| "    def create_node(cls, session, **kwargs):\n", | |
| "        \"\"\"Create a node from column values, insert it and commit.\n", | |
| "\n", | |
| "        Columns with a default (node_id, created_at, status) may be omitted.\n", | |
| "        Returns the new node.\n", | |
| "        \"\"\"\n", | |
| "        node = cls(**kwargs)\n", | |
| "        session.add(node)\n", | |
| "        session.commit()\n", | |
| "        return node\n", | |
| "\n", | |
| "    @classmethod\n", | |
| "    def get_node(cls, session, node_id=None, parent_id=None, value=None, before=None):\n", | |
| "        \"\"\"Return the latest version of a matching node, or None.\n", | |
| "\n", | |
| "        Every filter that is not None is applied (combined with AND). before\n", | |
| "        keeps only rows created at or before that time, so callers can look up\n", | |
| "        the version that was current then. Among matches, the highest id wins.\n", | |
| "        \"\"\"\n", | |
| "        q = select(cls)\n", | |
| "        if node_id is not None:\n", | |
| "            q = q.where(cls.node_id == node_id)\n", | |
| "        if parent_id is not None:\n", | |
| "            q = q.where(cls.parent_id == parent_id)\n", | |
| "        if value is not None:\n", | |
| "            q = q.where(cls.value == value)\n", | |
| "        if before is not None:\n", | |
| "            q = q.where(cls.created_at <= before)\n", | |
| "        q = q.order_by(cls.id.desc()).limit(1)\n", | |
| "        return session.scalar(q)\n", | |
| "\n", | |
| "    @property\n", | |
| "    def parent(self):\n", | |
| "        \"\"\"The latest version of this node's parent, or None for a root node.\n", | |
| "\n", | |
| "        Queries through the session this node is attached to.\n", | |
| "        \"\"\"\n", | |
| "        if self.parent_id is None:\n", | |
| "            return None\n", | |
| "        return self.get_node(Session.object_session(self), node_id=self.parent_id)\n", | |
| "\n", | |
| "    def get(self, node_type):\n", | |
| "        \"\"\"Return the nearest node of node_type on the path up to the root.\n", | |
| "\n", | |
| "        Ex. to query the 'institution' of a 'PI' when the 'PI' is a\n", | |
| "        child of the institution.\n", | |
| "\n", | |
| "        Raises KeyError(node_type) if no node on the path has that type.\n", | |
| "        \"\"\"\n", | |
| "        if self.node_type == node_type:\n", | |
| "            return self\n", | |
| "        # Look the parent up once; each access of .parent runs a query.\n", | |
| "        parent = self.parent\n", | |
| "        if parent is None:\n", | |
| "            raise KeyError(node_type)\n", | |
| "        return parent.get(node_type)\n", | |
| "\n", | |
| "    def set(self, **kwargs):\n", | |
| "        \"\"\"Insert a new version of this node with some values changed.\n", | |
| "\n", | |
| "        Honoured keywords: parent_id, value, display_value, status; a keyword\n", | |
| "        that is missing or None keeps the current value, and other keywords are\n", | |
| "        ignored. node_id and node_type are copied, so the new row becomes the\n", | |
| "        latest version of the same logical node; this row is left unchanged.\n", | |
| "\n", | |
| "        Returns the newly inserted node.\n", | |
| "        \"\"\"\n", | |
| "        # Reuse the caller's session without closing it: exiting a `with`\n", | |
| "        # block on it would close it and detach every object it has loaded.\n", | |
| "        session = Session.object_session(self)\n", | |
| "\n", | |
| "        def pick(name):\n", | |
| "            # `is None` rather than `or`, so falsy values such as '' can be set.\n", | |
| "            new = kwargs.get(name)\n", | |
| "            return getattr(self, name) if new is None else new\n", | |
| "\n", | |
| "        new_node = type(self)(\n", | |
| "            node_id=self.node_id,\n", | |
| "            parent_id=pick('parent_id'),\n", | |
| "            node_type=self.node_type,\n", | |
| "            value=pick('value'),\n", | |
| "            display_value=pick('display_value'),\n", | |
| "            status=pick('status'),\n", | |
| "        )\n", | |
| "        session.add(new_node)\n", | |
| "        session.commit()\n", | |
| "        return new_node\n", | |
| "\n", | |
| "    def to_dict(self, context=None):\n", | |
| "        \"\"\"Aggregate the attributes of this node and all its ancestors.\n", | |
| "\n", | |
| "        Returns a dict mapping each node_type on the path to the root to its\n", | |
| "        display_value, plus 'status' of the node the walk started from. A\n", | |
| "        non-empty context is updated in place and returned instead. If a\n", | |
| "        node_type repeats on the path, the entry closest to the root wins,\n", | |
| "        because ancestors are written last.\n", | |
| "        \"\"\"\n", | |
| "        if not context:\n", | |
| "            context = {'status': self.status}\n", | |
| "        context[self.node_type] = self.display_value\n", | |
| "\n", | |
| "        # Look the parent up once; each access of .parent runs a query.\n", | |
| "        parent = self.parent\n", | |
| "        if parent is not None:\n", | |
| "            parent.to_dict(context)\n", | |
| "\n", | |
| "        return context\n", | |
| "\n", | |
| "\n", | |
| "# Throwaway in-memory SQLite database for this demo.\n", | |
| "engine = create_engine(\"sqlite://\")\n", | |
| "# Emit CREATE TABLE for every model registered on Base.\n", | |
| "Base.metadata.create_all(bind=engine)" | |
| ], | |
| "metadata": { | |
| "collapsed": false, | |
| "ExecuteTime": { | |
| "start_time": "2023-05-02T22:23:58.076711Z", | |
| "end_time": "2023-05-02T22:23:58.081317Z" | |
| } | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 488, | |
| "outputs": [], | |
| "source": [ | |
| "from sqlalchemy.orm import Session\n", | |
| "\n", | |
| "with Session(engine) as session:\n", | |
| " # Create the example tree (each node shown by its value):\n", | |
| " # - bu.edu\n", | |
| " #   - bu.edu cs\n", | |
| " #     - [email protected]\n", | |
| " # - harvard.edu\n", | |
| " #   - harvard.edu cs\n", | |
| " # create_node commits every node, so a parent's node_id is already\n", | |
| " # assigned when its children are created below.\n", | |
| " bu = HierarchyNode.create_node(\n", | |
| " session,\n", | |
| " node_type='institution',\n", | |
| " value='bu.edu',\n", | |
| " display_value='Boston University'\n", | |
| " )\n", | |
| "\n", | |
| " # Field of science under Boston University.\n", | |
| " cs = HierarchyNode.create_node(\n", | |
| " session,\n", | |
| " node_type='field_of_science',\n", | |
| " parent_id=bu.node_id,\n", | |
| " value='bu.edu cs',\n", | |
| " display_value='Computer Science'\n", | |
| " )\n", | |
| "\n", | |
| " harvard = HierarchyNode.create_node(\n", | |
| " session,\n", | |
| " node_type='institution',\n", | |
| " value='harvard.edu',\n", | |
| " display_value='Harvard University'\n", | |
| " )\n", | |
| "\n", | |
| " # Field of science under Harvard; the next cell moves the PI here.\n", | |
| " harvard_cs = HierarchyNode.create_node(\n", | |
| " session,\n", | |
| " node_type='field_of_science',\n", | |
| " parent_id=harvard.node_id,\n", | |
| " value='harvard.edu cs',\n", | |
| " display_value='Harvard Computer Science'\n", | |
| " )\n", | |
| "\n", | |
| " # PI (leaf) under BU Computer Science.\n", | |
| " foo = HierarchyNode.create_node(\n", | |
| " session,\n", | |
| " node_type='pi',\n", | |
| " parent_id=cs.node_id,\n", | |
| " value='[email protected]',\n", | |
| " display_value='Foo Bar <[email protected]>'\n", | |
| " )" | |
| ], | |
| "metadata": { | |
| "collapsed": false, | |
| "ExecuteTime": { | |
| "start_time": "2023-05-02T22:23:58.083997Z", | |
| "end_time": "2023-05-02T22:23:58.092106Z" | |
| } | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 489, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "== This is the initial user before updates == \n", | |
| "{'status': 'Active', 'pi': 'Foo Bar <[email protected]>', 'field_of_science': 'Computer Science', 'institution': 'Boston University'}\n", | |
| "== This is the user after updating the parent ==\n", | |
| "{'status': 'Active', 'pi': 'Foo Bar <[email protected]>', 'field_of_science': 'Harvard Computer Science', 'institution': 'Harvard University'}\n", | |
| "== This is the user after updating the value and display value of the parent ==\n", | |
| "{'status': 'Active', 'pi': 'Foo Bar <[email protected]>', 'field_of_science': 'Computer Science and AI', 'institution': 'Harvard University'}\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "with Session(engine) as session:\n", | |
| "    # Retrieve the latest version of our user, then move them under Harvard's CS,\n", | |
| "    # which first requires looking up that node's node_id.\n", | |
| "    foo_user = HierarchyNode.get_node(session, value='[email protected]')\n", | |
| "    node_id_foo_user = foo_user.node_id\n", | |
| "    print(\"== This is the initial user before updates == \")\n", | |
| "    print(foo_user.to_dict())\n", | |
| "\n", | |
| "    node_id_harvard_cs = HierarchyNode.get_node(session, value='harvard.edu cs').node_id\n", | |
| "    # set() inserts a new version of the user; the result is re-queried below.\n", | |
| "    foo_user.set(parent_id=node_id_harvard_cs)\n", | |
| "\n", | |
| "    # Looking the user up again by value or by node_id returns the new version in both cases.\n", | |
| "    foo_user_2 = HierarchyNode.get_node(session, node_id=node_id_foo_user)\n", | |
| "    foo_user_3 = HierarchyNode.get_node(session, value='[email protected]')\n", | |
| "    assert node_id_foo_user == foo_user_2.node_id == foo_user_3.node_id\n", | |
| "    assert foo_user_2.parent_id == foo_user_3.parent_id == node_id_harvard_cs\n", | |
| "    print(\"== This is the user after updating the parent ==\")\n", | |
| "    print(foo_user_2.to_dict())\n", | |
| "\n", | |
| "    # Renaming Harvard's field of science inserts a new version of that node;\n", | |
| "    # the user's parent lookup returns the new version without touching the user.\n", | |
| "    harvard_cs = HierarchyNode.get_node(session, value='harvard.edu cs')\n", | |
| "    harvard_cs.set(value='harvard.edu CS/AI', display_value='Computer Science and AI')\n", | |
| "    foo_user_2 = HierarchyNode.get_node(session, node_id=node_id_foo_user)\n", | |
| "    assert foo_user_2.parent.value == 'harvard.edu CS/AI'\n", | |
| "    print(\"== This is the user after updating the value and display value of the parent ==\")\n", | |
| "    print(foo_user_2.to_dict())\n" | |
| ], | |
| "metadata": { | |
| "collapsed": false, | |
| "ExecuteTime": { | |
| "start_time": "2023-05-02T22:23:58.093731Z", | |
| "end_time": "2023-05-02T22:23:58.101204Z" | |
| } | |
| } | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 2 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython2", | |
| "version": "2.7.6" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 0 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment