Skip to content

Instantly share code, notes, and snippets.

@svenvarkel
Created October 25, 2023 07:49
Show Gist options
  • Select an option

  • Save svenvarkel/1936f343eb20a90fc94abc1cc2c4e0e5 to your computer and use it in GitHub Desktop.

Select an option

Save svenvarkel/1936f343eb20a90fc94abc1cc2c4e0e5 to your computer and use it in GitHub Desktop.
BeanieInfiniteRecursionTestCase
import asyncio
import unittest
from enum import Enum
from datetime import datetime, UTC
from typing import ClassVar, Optional, List
from pymongo import IndexModel, ASCENDING
from pydantic import BaseModel, Field, Extra
from beanie import Document, BackLink, Link, init_beanie
def utcnow():
"""
This returns current datetime in UTC timezone
"""
return datetime.now(UTC)
class EntityType(Enum):
    """
    Tag identifying which kind of document model an entity class represents.

    Each document class below sets its ``_entity_type`` class variable to one
    of these members; ``ABSTRACT`` is used by the shared base class.
    """

    DEAL = "deal"
    ASSET = "asset"
    LEASE = "lease"
    ABSTRACT = "abstract"
class AbstractEntity(Document):
    """
    Common base for the Deal, Asset and Lease documents.

    Provides ``created_at``/``updated_at`` timestamps, the shared pydantic
    configuration, and a base index list that subclasses extend through
    ``AbstractEntity.Settings.indexes``.
    """

    # class vars
    # Per-class entity tag; ClassVar keeps it out of the model's fields.
    _entity_type: ClassVar[EntityType] = EntityType.ABSTRACT

    # fields
    # Optional because the default is None (the old bare ``datetime``
    # annotation contradicted its own default).
    created_at: Optional[datetime] = Field(default=None)
    # default_factory gives each instance a fresh UTC timestamp.
    updated_at: datetime = Field(default_factory=utcnow)

    class Config:
        arbitrary_types_allowed = True
        extra = Extra.allow
        populate_by_name = True
        use_enum_values = False

    class Settings:
        name = "_abstract_entity"
        # NOTE(review): ``extra`` does not appear to be a documented Beanie
        # Settings option; kept unchanged - confirm whether it has any effect.
        extra = Extra.allow
        merge_indexes = False
        use_cache = False
        indexes = [
            # NOTE(review): MongoDB already enforces uniqueness on ``_id``, so
            # this compound unique index adds no extra constraint.
            IndexModel(
                [
                    ("_id", ASCENDING),
                    ("revision_id", ASCENDING),
                ],
                unique=True,
            ),
            IndexModel([("created_at", ASCENDING)]),
            IndexModel([("updated_at", ASCENDING)]),
        ]
class Deal(AbstractEntity):
    """
    A deal, indexed by its ``address``.

    ``asset`` is the back-link side of ``Asset.deal`` (named via
    ``original_field``); Beanie resolves it from the assets pointing here.
    """

    # class vars
    _entity_type: ClassVar[EntityType] = EntityType.DEAL

    # fields
    address: Optional[str] = Field(default=None)
    # "Asset" is a string forward reference: Asset is defined after Deal.
    # NOTE(review): pydantic v2 deprecates arbitrary Field kwargs; Beanie's
    # pydantic-v2 docs use json_schema_extra={"original_field": ...} instead.
    asset: Optional[BackLink["Asset"]] = Field(original_field="deal")

    class Settings:
        name = "deal"
        # Base indexes first, then the deal-specific address index.
        indexes = [
            *AbstractEntity.Settings.indexes,
            IndexModel([("address", ASCENDING)]),
        ]
class Asset(AbstractEntity):
    """
    An asset that links forward to its Deal and to its Leases.

    ``deal`` is the forward side of ``Deal.asset``; ``leases`` is the forward
    side of ``Lease.asset``.
    """

    # class vars
    _entity_type: ClassVar[EntityType] = EntityType.ASSET

    # fields
    name: Optional[str] = Field(default=None)
    deal: Optional[Link[Deal]] = Field(default=None)
    # "Lease" is a string forward reference: Lease is defined after Asset.
    # default_factory gives every asset its own empty list.
    leases: Optional[List[Link["Lease"]]] = Field(default_factory=list)

    class Settings:
        name = "asset"
        # Base indexes first, then the asset-specific name index.
        indexes = [
            *AbstractEntity.Settings.indexes,
            IndexModel([("name", ASCENDING)]),
        ]
class Lease(AbstractEntity):
    """
    A lease on an Asset, with an optional ``unit`` label.

    ``asset`` is the back-link side of ``Asset.leases`` (named via
    ``original_field``).
    """

    # class vars
    _entity_type: ClassVar[EntityType] = EntityType.LEASE

    # fields
    asset: BackLink[Asset] = Field(original_field="leases")
    unit: Optional[str] = Field(default=None)

    class Settings:
        name = "lease"
        # A fresh copy of the base index list; leases add no indexes of their own.
        indexes = list(AbstractEntity.Settings.indexes)
class TestData:
    """
    Raw field dictionaries used to seed the test database.

    Deals and assets are paired by address (a deal's ``"address"`` equals an
    asset's ``"name"``). Every lease carries the Elm Street address under
    ``"name"`` - not a declared Lease field, it relies on ``extra = allow`` -
    so the insert step attaches all three leases to that one asset.
    ``updated_at`` values are computed once, when this class body executes.
    """

    DEALS = [
        {"address": address, "created_at": created_at, "updated_at": utcnow()}
        for address, created_at in (
            (
                "1234 Elm Street, Springfield, IL 62701",
                "2021-01-01T00:00:00.000000+00:00",
            ),
            (
                "5678 Maple Avenue, Portland, OR 97204",
                "2022-01-01T00:00:00.000000+00:00",
            ),
            (
                "9876 Oak Drive, Dallas, TX 75201",
                "2023-01-01T00:00:00.000000+00:00",
            ),
        )
    ]

    ASSETS = [
        {"name": asset_name, "created_at": created_at, "updated_at": utcnow()}
        for asset_name, created_at in (
            (
                "1234 Elm Street, Springfield, IL 62701",
                "2021-01-01T00:00:00.000000+00:00",
            ),
            (
                "5678 Maple Avenue, Portland, OR 97204",
                "2022-01-01T00:00:00.000000+00:00",
            ),
            (
                "9876 Oak Drive, Dallas, TX 75201",
                "2023-01-01T00:00:00.000000+00:00",
            ),
        )
    ]

    LEASES = [
        {
            "name": "1234 Elm Street, Springfield, IL 62701",
            "created_at": "2021-01-01T00:00:00.000000+00:00",
            "updated_at": utcnow(),
            "unit": unit_label,
        }
        for unit_label in ("1", "2", "3")
    ]
class BeanieInfiniteRecursionTestCase(unittest.IsolatedAsyncioTestCase):
    """
    Demonstrate deep (near-infinite) recursion when fetching links.

    ``Deal.asset`` is a back-link to ``Asset.deal``, which links back to
    ``Deal``, so ``fetch_links=True`` can expand Deal -> Asset -> Deal -> ...
    when the result is inspected or printed.
    """

    # The MongoDB URI can be overridden through this environment variable so
    # the test is not tied to one local deployment / credential pair.
    MONGODB_URI_ENV_VAR = "BEANIE_TEST_MONGODB_URI"
    DEFAULT_MONGODB_URI = (
        "mongodb://dev:dev@localhost:27017/beanie_test?authSource=admin"
    )

    async def init_beanie(self):
        """
        Initialize Beanie for the Deal, Asset and Lease models.

        Uses ``$BEANIE_TEST_MONGODB_URI`` when set, otherwise the default local
        development URI.
        """
        import os

        # Inside this method the name ``init_beanie`` resolves to the
        # module-level function imported from beanie: methods are class
        # attributes, not globals, so there is no self-recursion here.
        await init_beanie(
            connection_string=os.environ.get(
                self.MONGODB_URI_ENV_VAR, self.DEFAULT_MONGODB_URI
            ),
            document_models=[Deal, Asset, Lease],
        )

    async def clear_data(self):
        """
        Delete every Deal, Asset and Lease document.
        """
        await Lease.delete_all()
        await Asset.delete_all()
        await Deal.delete_all()

    async def insert_data(self):
        """
        Insert the TestData fixtures.

        Each deal is saved first. Every asset whose ``name`` equals the deal's
        ``address`` is linked to that deal. Every lease whose ``name`` equals
        the asset's ``name`` is saved and appended to ``asset.leases``, and
        then the asset itself is saved.
        """
        for deal_data in TestData.DEALS:
            deal = Deal(**deal_data)
            await deal.save()
            for asset_data in TestData.ASSETS:
                if deal_data["address"] != asset_data["name"]:
                    continue
                asset = Asset(**asset_data)
                asset.deal = deal
                for lease_data in TestData.LEASES:
                    if asset_data["name"] != lease_data["name"]:
                        continue
                    lease = Lease(**lease_data)
                    # In-memory back-link only (the asset is not saved yet).
                    # NOTE(review): Beanie presumably derives Lease.asset from
                    # Asset.leases when fetching - confirm it is not persisted.
                    lease.asset = asset
                    await lease.save()
                    asset.leases.append(lease)
                await asset.save()

    async def asyncSetUp(self):
        """
        Connect to MongoDB and (re)seed the collections.
        """
        await self.init_beanie()
        # Start from empty collections so repeated runs do not accumulate
        # duplicate deals/assets/leases.
        await self.clear_data()
        await self.insert_data()

    async def asyncTearDown(self):
        """
        Tear down the test case.

        Data is intentionally left in MongoDB so it can be inspected after the
        run; ``asyncSetUp`` clears the collections before seeding again.
        """

    async def test_find(self):
        """
        Fetch all deals with links resolved; set a breakpoint to inspect them.
        """
        deals = await Deal.find(fetch_links=True).to_list()
        # FIXME stop debugger here and inspect deals
        # then continue and see (almost) infinite recursion under Deal->Asset->Deal->Asset.
        # It stops at some point but it's still more levels than needed
        # NOTE(review): newer Beanie releases document a ``nesting_depth``
        # argument to ``find`` (and ``max_nesting_depth`` in ``Settings``) to
        # bound this expansion - check against the installed Beanie version.
        print(deals)
        # Setup seeds exactly one deal per TestData.DEALS entry.
        self.assertEqual(len(deals), len(TestData.DEALS))
if __name__ == "__main__":
    # unittest.main() is synchronous and runs IsolatedAsyncioTestCase tests on
    # their own event loops; it must not be passed to asyncio.run(), which
    # requires a coroutine (and would never be reached, since main() exits).
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment