Last active
February 10, 2026 03:45
LiutongZhou/139eddad3ea732b9c5f06b97cf702799
Jax Distributed Zero to Hero
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "id": "8bfe2b28-24e4-4fa1-a761-79cf1985b09b", | |
| "metadata": { | |
| "editable": true, | |
| "slideshow": { | |
| "slide_type": "" | |
| }, | |
| "tags": [] | |
| }, | |
| "source": [ | |
| "# JAX Sharding and Parallel Computing - All-In-One Guide\n", | |
| "From Zero to Hero (v0.8+)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "393e9dce-d18b-40b7-932b-07b2e32b83a4", | |
| "metadata": {}, | |
| "source": [ | |
| "**TOC**\n", | |
| "1. [Imports](#Imports)\n", | |
| "2. [Create Device Mesh and Sharding Spec](#Create-Device-Mesh-and-Sharding-Spec)\n", | |
| "3. [Explicit Sharding](#Explicit-Sharding)\n", | |
| "4. [Manual Sharding](#Manual-Sharding)\n", | |
| " 1. [Simple Matrix Multiplication](#Simple-Matrix-Multiplication)\n", | |
| " 2. [The Collective Communication Ops](#The-Collective-Communication-Ops)\n", | |
| " 1. [PPermute](#PPermute)\n", | |
| " 2. [Reduce-Scatter](#Reduce-Scatter)\n", | |
| " 3. [All-Gather](#All-Gather)\n", | |
| " 4. [All-Reduce](#All-Reduce)\n", | |
| " 5. [All-to-All](#All-to-All)\n", | |
| "5. [Matrix Multiplication](#Matrix-Multiplication)\n", | |
| " 1. [Naive Version](#Naive-Version)\n", | |
| " 2. [FSDP Forward](#FSDP-Forward)\n", | |
| " 3. [Profiling `matmul_fsdp`](#Profiling-matmul_fsdp)\n", | |
| " 4. [`ppermute` Loop Version](#ppermute-Loop-Version)\n", | |
| " 5. [Matmul Reduce Scatter](#Matmul-Reduce-Scatter)\n", | |
| "6. [Practical Examples](#Practical-Examples)\n", | |
| " 1. [Distributed Training of Neural Networks](#Distributed-Training-of-Neural-Networks)\n", | |
| " 1. [Data Parallelism](#Data-Parallelism)\n", | |
| " 2. [FSDP](#FSDP)\n", | |
| " 3. [TP](#TP)\n", | |
| " 4. [FSDP + TP](#FSDP-+-TP)\n", | |
| " 5. [MOE Parallel with Token Dropping (Advanced)](#MOE-Parallel-with-Token-Dropping)\n", | |
| " 2. [Knowledge Distillation (Naive Pipeline Parallelism)](#Knowledge-Distillation) " | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "566d4b26-f47f-4345-b058-1519f9118c10", | |
| "metadata": {}, | |
| "source": [ | |
| "## Imports" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "id": "e1c0821c-c6aa-430f-b7f2-502db736029e", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import itertools\n", | |
| "from functools import partial\n", | |
| "from typing import Iterator\n", | |
| "\n", | |
| "import jax\n", | |
| "import jax.numpy as jnp\n", | |
| "from jax.sharding import (\n", | |
| " AxisType,\n", | |
| " Mesh,\n", | |
| " NamedSharding,\n", | |
| " PartitionSpec as P,\n", | |
| " auto_axes,\n", | |
| " explicit_axes,\n", | |
| " get_abstract_mesh,\n", | |
| " reshard,\n", | |
| ")\n", | |
| "from jaxtyping import Array, Float, Int, PyTree" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "id": "f0d55a5e-370b-4788-8247-0e174f7197b7", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "[CpuDevice(id=0),\n", | |
| " CpuDevice(id=1),\n", | |
| " CpuDevice(id=2),\n", | |
| " CpuDevice(id=3),\n", | |
| " CpuDevice(id=4),\n", | |
| " CpuDevice(id=5),\n", | |
| " CpuDevice(id=6),\n", | |
| " CpuDevice(id=7)]" | |
| ] | |
| }, | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "USE_CPU = True # Set to True to use CPU to simulate multiple GPUs/TPUs\n", | |
| "\n", | |
| "if USE_CPU:\n", | |
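| " # `backends_are_initialized` is a private JAX API; both options below only take effect before JAX initializes its backends\n", | |
| " if not jax._src.xla_bridge.backends_are_initialized():\n", | |
| " jax.config.update(\"jax_platforms\", \"cpu\")\n", | |
| " jax.config.update(\"jax_num_cpu_devices\", 8)\n", | |
| " # Alternatively, set these environment variables before importing jax (requires `import os`):\n", | |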
| " # os.environ['JAX_PLATFORMS'] = 'cpu'\n", | |
| " # os.environ['XLA_FLAGS'] = \"--xla_force_host_platform_device_count=8\"\n", | |
| "jax.devices()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "418b59ac-0f13-41a3-9ad3-3601339bdbd6", | |
| "metadata": {}, | |
| "source": [ | |
| "## Create Device Mesh and Sharding Spec" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "390707ba-f4da-4972-94b5-1acef60dcaca", | |
| "metadata": {}, | |
| "source": [ | |
| "In JAX, a device `Mesh` is an organized logical view of the available physical devices.\n", | |
| "\n", | |
| "If you have multiple (GPU/TPU) devices, you can arrange them into a device `Mesh` whose named axes are then referenced by sharding specs." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "id": "641139cf-d4b3-4a9c-a2c0-e72a34aa8efd", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Current mesh is: AbstractMesh('dp': 4, 'tp': 2, axis_types=(Explicit, Explicit), device_kind=cpu, num_cores=None)\n", | |
| "Current mesh is: AbstractMesh((), axis_types=())\n", | |
| "Current mesh is: AbstractMesh('dp': 4, 'tp': 2, axis_types=(Explicit, Explicit), device_kind=cpu, num_cores=None)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Device Mesh = named axes over the device set -> dp x tp devices\n", | |
| "global_mesh = jax.make_mesh(\n", | |
| " (4, 2),\n", | |
| " (\"dp\", \"tp\"),\n", | |
| " axis_types=(AxisType.Explicit, AxisType.Explicit),\n", | |
| " # AxisType.Explicit is the default in v0.9+ (New pattern)\n", | |
| " # Only use AxisType.Explicit or AxisType.Manual. Do not use AxisType.Auto (legacy default) for production\n", | |
| ")\n", | |
| "\n", | |
| "with jax.set_mesh(global_mesh):\n", | |
| " print(f\"Current mesh is: {get_abstract_mesh()}\")\n", | |
| "\n", | |
| "print(f\"Current mesh is: {get_abstract_mesh()}\")\n", | |
| "jax.set_mesh(\n", | |
| " global_mesh\n", | |
| ") # <- careful: this sets the mesh globally, for the rest of the process\n", | |
| "print(f\"Current mesh is: {get_abstract_mesh()}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "id": "71acc0fc-ae71-4007-8338-589621a98866", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Define Tensor Sharding Spec\n", | |
| "\n", | |
| "## Data Parallel: P('dp', None)\n", | |
| "## Split dim 0 into one partition per 'dp' index, placed along mesh axis 'dp';\n", | |
| "## dim 1 is not split, so each partition is replicated across mesh axis 'tp'\n", | |
| "shard_dp = NamedSharding(global_mesh, P(\"dp\", None))\n", | |
| "## Tensor Parallel: P(None, 'tp') is the mirror image: split dim 1 along 'tp', replicate across 'dp'\n", | |
| "shard_tp = NamedSharding(global_mesh, P(None, \"tp\")) # tensor parallel\n", | |
| "shard_fully_replicate = NamedSharding(global_mesh, P()) # fully replicated" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "74664a76-43dd-4753-a2c5-c177a01eaa75", | |
| "metadata": {}, | |
| "source": [ | |
| "You can create many different Meshes and Sharding Specs. \n", | |
| "\n", | |
| "Technically, you can create meshes that share GPUs (e.g. `GPU:0` in both `mesh_a` and `mesh_b`). \n", | |
| "\n", | |
| "A use case of having two meshes: see [Practical Examples - **Knowledge Distillation**](#Knowledge-Distillation)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "f59789a6-b434-4ae8-9e91-63c82baa16b9", | |
| "metadata": {}, | |
| "source": [ | |
| "## Explicit Sharding" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "1507f09f-2ef9-4d81-beef-ca1cb4260e11", | |
| "metadata": {}, | |
| "source": [ | |
| "Given a 2D device mesh with axis names `(\"dp\", \"tp\")`\n", | |
| "* Notation: `float32[1024@dp,4096]`\n", | |
| "* Interpretation:\n", | |
| " * a tensor of `dtype=float32` and of shape `(1024, 4096)`\n", | |
| " * sharded (split) along `dim=0` into `dp` partitions and placed across the device mesh axis `dp`\n", | |
| " * not split along `dim=1`: each partition is replicated across the device mesh axis `tp`\n", | |
| "\n", | |
| "Similarly, given a 6D device mesh with axis names `(\"pp\", \"dp\", \"fsdp\", \"sp\", \"tp\", \"ep\")` \n", | |
| "* Notation: `bfloat16[512, 2048@(sp, tp), 128_000]`\n", | |
| "* Interpretation:\n", | |
| " * a tensor of `dtype=bfloat16` and of shape `(512, 2048, 128_000)`\n", | |
| " * sharded (split) along `dim=1` into `sp x tp` partitions and placed across the device mesh axes `sp x tp`\n", | |
| " * not split along `dim=0` or `dim=2`: each partition is replicated across the remaining device mesh axes `pp`, `dp`, `fsdp` and `ep` (`pp x dp x fsdp x ep` replicas)" | |
| ] | |
| }, | |
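To check the notation against real shards, here is a minimal sketch, assuming JAX is installed and running in a fresh process. It rebuilds an 8-device CPU mesh with the classic `Mesh` constructor (instead of `jax.make_mesh`) and prints the per-device block implied by `float32[1024@dp,4096]`:

```python
import os

# Simulate 8 CPU devices; this only takes effect if set before JAX initializes.
os.environ.setdefault("JAX_PLATFORMS", "cpu")
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:8]).reshape(4, 2), ("dp", "tp"))
sharding = NamedSharding(mesh, P("dp", None))  # float32[1024@dp,4096]

# dim 0 is split 4 ways along 'dp'; dim 1 is kept whole on every device.
print(sharding.shard_shape((1024, 4096)))  # (256, 4096)

x = jax.device_put(jnp.zeros((1024, 4096), jnp.float32), sharding)
for shard in x.addressable_shards:
    # Devices that differ only in their 'tp' coordinate hold the same rows.
    print(shard.device, shard.index, shard.data.shape)
```

Each block of 256 rows shows up on two devices, one per `tp` index, which is the replication the notation describes.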
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "id": "073ebb9a-afc4-4356-b614-e995bf753f40", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "jax.typeof(x)=ShapedArray(float32[1024@dp,4096])\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "x = jax.random.normal(jax.random.key(0), (1024, 4096))\n", | |
| "# put tensor from default device (CPU) to accelerator device\n", | |
| "# and explicitly shard x along axis 0 to distribute partitions across mesh axis dp\n", | |
| "x = jax.device_put(x, shard_dp)\n", | |
| "print(f\"{jax.typeof(x)=}\") # float32[1024@dp,4096]\n", | |
| "# jax.debug.visualize_array_sharding(x) # to see how x is sharded horizontally" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "id": "2a1e26ab-8e7e-495e-8d25-47780821afbc", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "jax.typeof(w)=ShapedArray(float32[4096,256@tp])\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "w = jax.random.normal(jax.random.key(1), (4096, 256))\n", | |
| "# explicitly shard model weight w along axis 1 to distribute partitions across mesh axis tp\n", | |
| "w = jax.device_put(w, shard_tp)\n", | |
| "print(f\"{jax.typeof(w)=}\") # float32[4096,256@tp]\n", | |
| "# jax.debug.visualize_array_sharding(w) # to see how w is sharded vertically" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "id": "1f4a8be4-e807-4b05-8e85-eab9b6c5ae9f", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Current mesh is: AbstractMesh('dp': 4, 'tp': 2, axis_types=(Explicit, Explicit), device_kind=cpu, num_cores=None)\n", | |
| "\n", | |
| "Shard of x: NamedSharding(mesh=Mesh('dp': 4, 'tp': 2, axis_types=(Auto, Auto)), spec=PartitionSpec('dp',), memory_kind=device)\n", | |
| "Shard of w: NamedSharding(mesh=Mesh('dp': 4, 'tp': 2, axis_types=(Auto, Auto)), spec=PartitionSpec(None, 'tp'), memory_kind=device)\n", | |
| "\n", | |
| "jax.typeof(y)=ShapedArray(float32[1024@dp,256@tp])\n", | |
| "jax.typeof(y_resharded)=ShapedArray(float32[1024,256])\n", | |
| "jax.typeof(y_resharded_2)=ShapedArray(float32[1024@tp,256@dp])\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.jit\n", | |
| "def f(x: Float[Array, \"batch hidden\"], w: Float[Array, \"hidden out_dim\"]) -> tuple[\n", | |
| " Float[Array, \"batch out_dim\"],\n", | |
| " Float[Array, \"batch out_dim\"],\n", | |
| " Float[Array, \"batch out_dim\"],\n", | |
| "]:\n", | |
| " # float32[1024@dp,4096] @ float32[4096,256@tp] -> (1024@dp, 256@tp)\n", | |
| " jax.debug.inspect_array_sharding(x, callback=lambda x: print(f\"\\nShard of x: {x}\"))\n", | |
| " jax.debug.inspect_array_sharding(w, callback=lambda w: print(f\"Shard of w: {w}\\n\"))\n", | |
| " # ^ use this to debug inside jit\n", | |
| " y = x @ w\n", | |
| "\n", | |
| " print(f\"Current mesh is: {get_abstract_mesh()}\")\n", | |
| " # Note: for explicit axis type, with_sharding_constraint acts like an assert check\n", | |
| " y = jax.lax.with_sharding_constraint(y, NamedSharding(global_mesh, P(\"dp\", \"tp\")))\n", | |
| " # Only in auto axis type, it (implicitly) acts like a reshard.\n", | |
| "\n", | |
| " # Use reshard explicitly whenever a different layout is needed, so that its communication cost is visible in the code\n", | |
| " y_resharded = reshard(y, P()) # <- triggers an all-gather\n", | |
| " y_resharded_2 = reshard(y, P(\"tp\", \"dp\")) # <- triggers an all-to-all\n", | |
| " return y, y_resharded, y_resharded_2\n", | |
| "\n", | |
| "\n", | |
| "y, y_resharded, y_resharded_2 = f(x, w)\n", | |
| "print(f\"{jax.typeof(y)=}\")\n", | |
| "print(f\"{jax.typeof(y_resharded)=}\")\n", | |
| "print(f\"{jax.typeof(y_resharded_2)=}\")\n", | |
| "assert jnp.all(y == y_resharded)\n", | |
| "assert jnp.all(y_resharded == y_resharded_2)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "28d2d434-a823-4c43-b729-4e83b084a181", | |
| "metadata": {}, | |
| "source": [ | |
| "## Manual Sharding\n", | |
| "`jax.shard_map` combined with `jax.lax.*` collectives is preferred for production code because every collective, and therefore every communication cost, is written out explicitly." | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "94a4f592-1298-4ee9-81db-d03bc34f43f6", | |
| "metadata": {}, | |
| "source": [ | |
| "### Simple Matrix Multiplication" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "9664e50e-f072-4b0d-ba87-2edd81ae40ab", | |
| "metadata": {}, | |
| "source": [ | |
| "Given a device mesh `Mesh('dp': 4, 'tp': 2)`, inside `jax.shard_map`:\n", | |
| "* Notation: `float32[256{dp=4}, 4096]` (JAX itself prints this as `float32[256,4096]{dp}`)\n", | |
| "* Interpretation:\n", | |
| " * a local tensor block of dtype `float32` and shape `(256, 4096)`\n", | |
| " * its contents vary across the device mesh axis `dp` (each `dp` index holds a different block of 256 rows) and are identical (replicated) across the device mesh axis `tp`" | |
| ] | |
| }, | |
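The local block shape follows mechanically from the global shape, the `PartitionSpec` and the mesh sizes: each dim is divided by the product of the sizes of the mesh axes it is split over. A minimal pure-Python sketch (the helper `local_block_shape` is hypothetical, not a JAX API; `NamedSharding.shard_shape` reports the same information in JAX) reproduces the block shapes printed by the next cell:

```python
from math import prod


def local_block_shape(global_shape, spec, mesh_shape):
    """Shape of each device's block inside shard_map (illustrative helper).

    spec has one entry per leading array dim: None (not split), a mesh axis
    name, or a tuple of names (split over the product of their sizes).
    Dims beyond len(spec) are not split.
    """
    block = []
    for size, entry in zip(global_shape, spec):
        if entry is None:
            names = ()
        elif isinstance(entry, str):
            names = (entry,)
        else:
            names = tuple(entry)
        parts = prod(mesh_shape[name] for name in names)  # prod of nothing is 1
        if size % parts:
            raise ValueError(f"dim of size {size} is not divisible by {parts}")
        block.append(size // parts)
    return tuple(block) + tuple(global_shape[len(spec):])


mesh_shape = {"dp": 4, "tp": 2}
print(local_block_shape((1024, 4096), ("dp", None), mesh_shape))  # (256, 4096)
print(local_block_shape((4096, 256), (None, "tp"), mesh_shape))   # (4096, 128)
print(local_block_shape((8,), (("tp", "dp"),), mesh_shape))       # (1,)
```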
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "id": "4eaac26d-58dd-4d35-bb6c-9e5b93cfc3cd", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "jax.typeof(x_block)=ShapedArray(float32[256,4096]{dp})\n", | |
| "jax.typeof(w_block)=ShapedArray(float32[4096,128]{tp})\n", | |
| "jax.typeof(y_block)=ShapedArray(float32[256,128]{tp,dp})\n", | |
| "jax.typeof(y)=ShapedArray(float32[1024@dp,256@tp])\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "x = jax.random.normal(jax.random.key(0), (1024, 4096))\n", | |
| "w = jax.random.normal(jax.random.key(1), (4096, 256))\n", | |
| "\n", | |
| "\n", | |
| "@jax.jit # Compile the manual parallel code. For easier debugging, you can comment out jit compilation\n", | |
| "@jax.shard_map(\n", | |
| " mesh=global_mesh, in_specs=(P(\"dp\", None), P(None, \"tp\")), out_specs=P(\"dp\", \"tp\")\n", | |
| ")\n", | |
| "def matmul(\n", | |
| " x_block: Float[Array, \"dp_block hidden\"], w_block: Float[Array, \"hidden tp_block\"]\n", | |
| ") -> Float[Array, \"dp_block tp_block\"]:\n", | |
| " \"\"\"This function will run in parallel (concurrently) across devices\n", | |
| "\n", | |
| " All 8 devices (dp=4 x tp=2) run it concurrently, each on its own local blocks\n", | |
| " \"\"\"\n", | |
| " # f32[1024@dp=4, 4096] @ f32[4096, 256@tp=2]\n", | |
| " # f32[256{dp=4}, 4096] @ f32[4096, 128{tp=2}]\n", | |
| " print(f\"{jax.typeof(x_block)=}\")\n", | |
| " print(f\"{jax.typeof(w_block)=}\")\n", | |
| " y_block = x_block @ w_block # float32[256{dp=4},128{tp=2}]\n", | |
| " print(f\"{jax.typeof(y_block)=}\")\n", | |
| " # the varying local block y_block will be concatenated per output sharding spec: P(\"dp\", \"tp\")\n", | |
| " # concatenate y_block indexed by dp axis along dim=0 and indexed by tp axis along dim=1 to form the final output\n", | |
| " # float32[1024@dp,256@tp]\n", | |
| " return y_block\n", | |
| "\n", | |
| "\n", | |
| "y = matmul(x, w) # with shard_map, no need to do device_put\n", | |
| "print(f\"{jax.typeof(y)=}\") # float32[1024@dp,256@tp]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "id": "afe14164-2215-4c3e-97ca-92f8b16f81b6", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "13.4 ms ± 260 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.jit\n", | |
| "def f(x, w):\n", | |
| " return x @ w\n", | |
| "\n", | |
| "\n", | |
| "%timeit jax.block_until_ready(f(x,w)) # unsharded baseline" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "id": "0e9d55c3-38af-486f-8615-aab8213474be", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "4.44 ms ± 79 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%timeit jax.block_until_ready(matmul(x,w)) # sharded: each device computes one 256x128 output block" | |
| ] | |
| }, | |
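A back-of-the-envelope check on those timings: because neither input is split along the contracted `hidden` dimension, each device computes its output block from local data only (no collectives), doing 1/8 of the total multiply-adds. A small sketch of the arithmetic (`matmul_flops` is just an illustrative helper):

```python
def matmul_flops(m: int, k: int, n: int) -> int:
    """FLOPs of an (m, k) @ (k, n) matmul: one multiply and one add per term."""
    return 2 * m * k * n


dp, tp = 4, 2
full = matmul_flops(1024, 4096, 256)                     # x @ w on one device
per_device = matmul_flops(1024 // dp, 4096, 256 // tp)   # x_block @ w_block
print(full, per_device, full // per_device)  # 2147483648 268435456 8
```

The measured speedup above (13.4 ms vs 4.44 ms, roughly 3x) is well below 8x; a plausible reason is that the 8 simulated CPU devices share the same host cores, so CPU timings are only a rough illustration.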
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "id": "bdc9ed88-89b7-4eaf-aa24-de43a0a68ac5", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "assert jnp.allclose(reshard(y, P()), x @ w, atol=1e-4) # verify correctness" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "edbc2ef5-77c1-4f93-8d64-da17c9427912", | |
| "metadata": {}, | |
| "source": [ | |
| "### The Collective Communication Ops" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "2f66b5db-de48-4487-84df-d879f1d7558c", | |
| "metadata": {}, | |
| "source": [ | |
| "#### PPermute \n", | |
| "Point-to-point communication.\n", | |
| "\n", | |
| "Communication cost: $O(n)$, where $n$ is the number of devices along the device mesh axis (`jax.lax.axis_size(\"i\")`); in a full permutation such as a ring shift, each device sends exactly one block.\n", | |
| "\n", | |
| "Use cases: sending activations between stages in pipeline parallelism; ring attention." | |
| ] | |
| }, | |
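The semantics of `jax.lax.ppermute` can be modeled without any devices: each `(src, dst)` pair in `perm` copies one block, and any device that is not the destination of some pair receives zeros. The sketch below is a hypothetical single-process helper (`ppermute_reference` is not a JAX API), handy for predicting what a `perm` list does before running it under `shard_map`:

```python
def ppermute_reference(blocks: list[list[int]], perm: list[tuple[int, int]]) -> list[list[int]]:
    """Pure-Python model of jax.lax.ppermute along one mesh axis.

    blocks[i] is the block on the device with index i along the axis.
    Each (src, dst) pair sends blocks[src] to dst; devices that are not a
    dst of any pair receive zeros.
    """
    srcs = [s for s, _ in perm]
    dsts = [d for _, d in perm]
    if len(set(srcs)) != len(srcs) or len(set(dsts)) != len(dsts):
        raise ValueError("perm must not repeat a source or a destination")
    out = [[0] * len(b) for b in blocks]
    for src, dst in perm:
        out[dst] = list(blocks[src])
    return out


blocks = [[10], [11], [12], [13]]
n = len(blocks)
ring_forward = [(i, (i + 1) % n) for i in range(n)]
print(ppermute_reference(blocks, ring_forward))  # [[13], [10], [11], [12]]
print(ppermute_reference(blocks, [(0, 1)]))      # [[0], [10], [0], [0]]
```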
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "id": "1b97435a-c146-4818-83b3-885fd4ad4e02", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def ring_shift_simple_version(block: Array, axis_name: str, offset: int = 1) -> Array:\n", | |
| " \"\"\"Ring shift forward or backward along the device mesh axis\n", | |
| "\n", | |
| " Parameters\n", | |
| " --------\n", | |
| " block : Array\n", | |
| " axis_name: str\n", | |
| " offset : int\n", | |
| " If positive, ring shift forward; if negative, ring shift backward.\n", | |
| " Default is shift forward with step size=1\n", | |
| "\n", | |
| " Returns\n", | |
| " ----------\n", | |
| " receive : Array\n", | |
| " the received block from previous device if shift forward\n", | |
| " the received block from next device if shift backward\n", | |
| " \"\"\"\n", | |
| " num_devices = jax.lax.axis_size(axis_name)\n", | |
| " # number of devices along the device mesh axis\n", | |
| " src_to_dst = [(idx, (idx + offset) % num_devices) for idx in range(num_devices)]\n", | |
| " return jax.lax.ppermute(block, axis_name, perm=src_to_dst)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "id": "489562fc-4c6d-4d97-b9b3-84d751299ab7", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "[[0 1 2 3]\n", | |
| " [4 5 6 7]]\n", | |
| "jax.typeof(x_block)=ShapedArray(int32[1,1]{tp,dp})\n", | |
| "jax.typeof(x_block_recv_from_prev)=ShapedArray(int32[1,1]{tp,dp})\n", | |
| "jax.typeof(x_block_recv_from_next)=ShapedArray(int32[1,1]{tp,dp})\n", | |
| "jax.typeof(y_shift_forward)=ShapedArray(int32[2@tp,4@dp])\n", | |
| "[[3 0 1 2]\n", | |
| " [7 4 5 6]]\n", | |
| "jax.typeof(y_shift_backward)=ShapedArray(int32[2@tp,4@dp])\n", | |
| "[[1 2 3 0]\n", | |
| " [5 6 7 4]]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.shard_map(\n", | |
| " mesh=global_mesh, in_specs=P(\"tp\", \"dp\"), out_specs=(P(\"tp\", \"dp\"), P(\"tp\", \"dp\"))\n", | |
| ")\n", | |
| "def ring_shift_demo(\n", | |
| " x_block: Int[Array, \"tp_block dp_block\"],\n", | |
| ") -> tuple[Int[Array, \"tp_block dp_block\"], Int[Array, \"tp_block dp_block\"]]:\n", | |
| " \"\"\"Demo the ring shift pattern\"\"\"\n", | |
| " # i32[2@tp,4@dp] -> i32[1{tp=2},1{dp=4}] varying -> i32[2@tp, 4@dp]\n", | |
| " print(f\"{jax.typeof(x_block)=}\")\n", | |
| "\n", | |
| " axis_name = \"dp\"\n", | |
| " # ring shift forward\n", | |
| " x_block_recv_from_prev = ring_shift_simple_version(x_block, axis_name, offset=1)\n", | |
| " print(f\"{jax.typeof(x_block_recv_from_prev)=}\")\n", | |
| "\n", | |
| " # ring shift backward\n", | |
| " x_block_recv_from_next = ring_shift_simple_version(x_block, axis_name, offset=-1)\n", | |
| " print(f\"{jax.typeof(x_block_recv_from_next)=}\")\n", | |
| " return x_block_recv_from_prev, x_block_recv_from_next\n", | |
| "\n", | |
| "\n", | |
| "y = jnp.arange(8).reshape(2, 4)\n", | |
| "print(y)\n", | |
| "y_shift_forward, y_shift_backward = ring_shift_demo(y)\n", | |
| "\n", | |
| "print(f\"{jax.typeof(y_shift_forward)=}\")\n", | |
| "print(y_shift_forward)\n", | |
| "print(f\"{jax.typeof(y_shift_backward)=}\")\n", | |
| "print(y_shift_backward)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 14, | |
| "id": "81177087-9ad9-49da-9eab-a3a933dc2eb9", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "[[0 1]\n", | |
| " [2 3]\n", | |
| " [4 5]\n", | |
| " [6 7]]\n", | |
| "jax.typeof(x_block)=ShapedArray(int32[1,1]{tp,dp})\n", | |
| "jax.typeof(x_block_recv_from_prev)=ShapedArray(int32[1,1]{tp,dp})\n", | |
| "jax.typeof(x_block_recv_from_next)=ShapedArray(int32[1,1]{tp,dp})\n", | |
| "jax.typeof(y_shift_forward)=ShapedArray(int32[4@dp,2@tp])\n", | |
| "[[6 7]\n", | |
| " [0 1]\n", | |
| " [2 3]\n", | |
| " [4 5]]\n", | |
| "jax.typeof(y_shift_backward)=ShapedArray(int32[4@dp,2@tp])\n", | |
| "[[2 3]\n", | |
| " [4 5]\n", | |
| " [6 7]\n", | |
| " [0 1]]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.shard_map(\n", | |
| " mesh=global_mesh, in_specs=P(\"dp\", \"tp\"), out_specs=(P(\"dp\", \"tp\"), P(\"dp\", \"tp\"))\n", | |
| ")\n", | |
| "def ring_shift_demo(\n", | |
| " x_block: Int[Array, \"dp_block tp_block\"],\n", | |
| ") -> tuple[Int[Array, \"dp_block tp_block\"], Int[Array, \"dp_block tp_block\"]]:\n", | |
| " \"\"\"Demo the ring shift pattern\"\"\"\n", | |
| " # i32[4@dp,2@tp] -> i32[1{dp=4},1{tp=2}] varying -> i32[4@dp, 2@tp]\n", | |
| " print(f\"{jax.typeof(x_block)=}\")\n", | |
| "\n", | |
| " axis_name = \"dp\"\n", | |
| " # ring shift forward\n", | |
| " x_block_recv_from_prev = ring_shift_simple_version(x_block, axis_name, offset=1)\n", | |
| " print(f\"{jax.typeof(x_block_recv_from_prev)=}\")\n", | |
| "\n", | |
| " # ring shift backward\n", | |
| " x_block_recv_from_next = ring_shift_simple_version(x_block, axis_name, offset=-1)\n", | |
| " print(f\"{jax.typeof(x_block_recv_from_next)=}\")\n", | |
| " return x_block_recv_from_prev, x_block_recv_from_next\n", | |
| "\n", | |
| "\n", | |
| "y = jnp.arange(8).reshape(4, 2)\n", | |
| "print(y)\n", | |
| "y_shift_forward, y_shift_backward = ring_shift_demo(y)\n", | |
| "\n", | |
| "print(f\"{jax.typeof(y_shift_forward)=}\")\n", | |
| "print(y_shift_forward)\n", | |
| "print(f\"{jax.typeof(y_shift_backward)=}\")\n", | |
| "print(y_shift_backward)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "3885b2c3-14d3-4e03-b457-bd7b60f0e9ce", | |
| "metadata": {}, | |
| "source": [ | |
| "<div class=\"alert alert-danger\">\n", | |
| "`ring_shift_simple_version` is only correct when we ppermute along a single device mesh axis, or along several axes listed in the same order as in the mesh. If the axes are listed in a different order, it returns wrongly ordered results.\n", | |
| "</div>" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 15, | |
| "id": "88ab9978-dc2b-4acf-86f9-3f05d3c07688", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "[0 1 2 3 4 5 6 7]\n", | |
| "jax.typeof(y_shift_forward)=ShapedArray(int32[8@(tp,dp)])\n", | |
| "[7 4 5 6 0 1 2 3]\n", | |
| "jax.typeof(y_shift_backward)=ShapedArray(int32[8@(tp,dp)])\n", | |
| "[4 5 6 7 1 2 3 0]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.shard_map(\n", | |
| " mesh=global_mesh,\n", | |
| " in_specs=P((\"tp\", \"dp\")),\n", | |
| " out_specs=(P((\"tp\", \"dp\")), P((\"tp\", \"dp\"))),\n", | |
| ")\n", | |
| "def ring_shift_demo(\n", | |
| " x_block: Int[Array, \"n\"],\n", | |
| ") -> tuple[Int[Array, \"n\"], Int[Array, \"n\"]]:\n", | |
| " \"\"\"Demo the ring shift pattern\"\"\"\n", | |
| " # i32[8@(tp,dp)] -> i32[1]{tp=2,dp=4} varying -> i32[8@(tp,dp)]\n", | |
| " axis_name = (\"tp\", \"dp\")\n", | |
| " # ring shift forward\n", | |
| " x_block_recv_from_prev = ring_shift_simple_version(x_block, axis_name, offset=1)\n", | |
| " # ring shift backward\n", | |
| " x_block_recv_from_next = ring_shift_simple_version(x_block, axis_name, offset=-1)\n", | |
| " return x_block_recv_from_prev, x_block_recv_from_next\n", | |
| "\n", | |
| "\n", | |
| "y = jnp.arange(8)\n", | |
| "print(y)\n", | |
| "y_shift_forward, y_shift_backward = ring_shift_demo(y)\n", | |
| "\n", | |
| "print(f\"{jax.typeof(y_shift_forward)=}\")\n", | |
| "print(y_shift_forward)\n", | |
| "print(f\"{jax.typeof(y_shift_backward)=}\")\n", | |
| "print(y_shift_backward)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "ecd004eb-bca2-4007-a72c-b516c7fc95af", | |
| "metadata": {}, | |
| "source": [ | |
| "<div class=\"alert alert-warning\">Why?</div>\n", | |
| "\n", | |
| "Assuming the device mesh is `Mesh(dp=2, tp=2)` and the input partition spec is `P((\"tp\", \"dp\"))` for `[0, 1, 2, 3]`, we would expect a ring shift backward to return `[1,2,3,0]` of type `i32[4@(tp,dp)]`. \n", | |
| "\n", | |
| "First, `[0,1,2,3]` is partitioned and placed onto\n", | |
| "```\n", | |
| "[[(dp=0,tp=0)(device_id=0) 0, (dp=0,tp=1)(device_id=1) 2],\n", | |
| " [(dp=1,tp=0)(device_id=2) 1, (dp=1,tp=1)(device_id=3) 3]]\n", | |
| "```\n", | |
| "according to the input partition spec `P((\"tp\", \"dp\"))`, which linearizes `(tp, dp)` in row-major order: the last axis, `\"dp\"`, varies fastest.\n", | |
| "\n", | |
| "However, the linear indices in the `(src, dst)` pairs of the `perm` argument of `jax.lax.ppermute(block, axis_names, perm=src_to_dst)` refer to the device index (row-major over the mesh axis order `(\"dp\", \"tp\")`) rather than the logical index implied by `P((\"tp\", \"dp\"))`.\n", | |
| "\n", | |
| "The rule `src_to_dst = [(idx, (idx + offset) % num_devices) for idx in range(num_devices)]` is correct when there is only one ppermute axis, but when you `ppermute` along multiple axes whose order differs from the mesh axis order, it silently scrambles the data order. Here, for a backward shift, device 0 receives `2` from device 1, device 1 receives `1` from device 2, device 2 receives `3` from device 3 and device 3 receives `0` from device 0; read back in logical order, the result is `[2, 3, 1, 0]` instead of `[1, 2, 3, 0]`.\n", | |
| "\n", | |
| "The fix is to map each logical index to its device index and build `perm` from device indices, as `ring_shift` below does." | |
| ] | |
| }, | |
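The argument above can be replayed in plain Python for the 8-device `Mesh('dp': 4, 'tp': 2)` used in this notebook. The sketch assumes, as stated above, that `ppermute` interprets `(src, dst)` as device indices in mesh axis order; `linear`, `shift` and the other names are illustrative only. It reproduces the wrong backward shift printed earlier, then the corrected one:

```python
from itertools import product

MESH = {"dp": 4, "tp": 2}  # mesh axis order: device index = dp * 2 + tp
SPEC = ("tp", "dp")        # P(("tp", "dp")): logical index = tp * 4 + dp


def linear(coords: dict[str, int], axes) -> int:
    """Row-major linear index over `axes`: the last axis varies fastest."""
    idx = 0
    for name in axes:
        idx = idx * MESH[name] + coords[name]
    return idx


# Device coordinates enumerated in device (mesh) order.
devices = [dict(zip(MESH, c)) for c in product(*(range(s) for s in MESH.values()))]
n = len(devices)
held = [linear(c, SPEC) for c in devices]  # device d holds y[held[d]] of y = arange(8)


def shift(perm: list[tuple[int, int]]) -> list[int]:
    """Apply a ppermute with device-index pairs, then read back in logical order."""
    received = [0] * n
    for src, dst in perm:
        received[dst] = held[src]
    result = [0] * n
    for d, c in enumerate(devices):
        result[linear(c, SPEC)] = received[d]
    return result


naive_backward = [(i, (i - 1) % n) for i in range(n)]  # logical idx used as device idx
print(shift(naive_backward))  # [4, 5, 6, 7, 1, 2, 3, 0]  (wrong)

logical_to_device = [0] * n
for d, c in enumerate(devices):
    logical_to_device[linear(c, SPEC)] = d
fixed_backward = [(logical_to_device[i], logical_to_device[(i - 1) % n]) for i in range(n)]
print(shift(fixed_backward))  # [1, 2, 3, 4, 5, 6, 7, 0]
```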
| { | |
| "cell_type": "code", | |
| "execution_count": 16, | |
| "id": "7096fe4d-8152-4f6d-a20b-e6c07c70701d", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# A generalized implementation\n", | |
| "def ring_shift(\n", | |
| " block: Array, axis_name: str | tuple[str, ...], offset: int = 1\n", | |
| ") -> Array:\n", | |
| " \"\"\"Ring shift forward or backward along the device mesh axis\n", | |
| "\n", | |
| " Parameters\n", | |
| " --------\n", | |
| " block : Array\n", | |
| " axis_name: str | tuple[str, ...]\n", | |
| " Single axis name or tuple of axis names for multi-axis ring shift\n", | |
| " offset : int\n", | |
| " If positive, ring shift forward; if negative, ring shift backward.\n", | |
| " Default is shift forward with step size=1\n", | |
| "\n", | |
| " Returns\n", | |
| " ----------\n", | |
| " receive : Array\n", | |
| " the received block from previous device if shift forward\n", | |
| " the received block from next device if shift backward\n", | |
| " \"\"\"\n", | |
| " if isinstance(axis_name, str):\n", | |
| " axis_name = (axis_name,)\n", | |
| " logical_axes = {a: jax.lax.axis_size(a) for a in axis_name}\n", | |
| " mesh = get_abstract_mesh()\n", | |
| " mesh_axes = {a: logical_axes[a] for a in mesh.axis_names if a in logical_axes}\n", | |
| " num_devices = jax.lax.axis_size(axis_name)\n", | |
| "\n", | |
| " def linear_idx_to_coordinates(idx: int, axes: dict[str, int]) -> dict[str, int]:\n", | |
| " \"\"\"Convert linear idx to coordinates according to the axes order\"\"\"\n", | |
| " coords: dict[str, int] = {}\n", | |
| " for name, size in reversed(axes.items()):\n", | |
| " idx, rem = divmod(idx, size)\n", | |
| " coords[name] = rem\n", | |
| " return dict(reversed(coords.items()))\n", | |
| "\n", | |
| " def coordinates_to_linear_idx(coords: dict[str, int], axes: dict[str, int]) -> int:\n", | |
| " \"\"\"Linearize coordinates according to the axes order to linear idx\"\"\"\n", | |
| " idx = 0\n", | |
| " stride = 1\n", | |
| " for name, size in reversed(axes.items()):\n", | |
| " idx += coords[name] * stride\n", | |
| " stride *= size\n", | |
| " return idx\n", | |
| "\n", | |
| " # map each logical linear index (axis order of `axis_name`) to the device linear index (mesh axis order)\n", | |
| " logical_to_mesh = []\n", | |
| " for logical_idx in range(num_devices):\n", | |
| " coords = linear_idx_to_coordinates(logical_idx, logical_axes)\n", | |
| " mesh_idx = coordinates_to_linear_idx(coords, mesh_axes)\n", | |
| " logical_to_mesh.append(mesh_idx)\n", | |
| "\n", | |
| " src_to_dst = []\n", | |
| " for src_logical_idx in range(num_devices):\n", | |
| " dst_logical_idx = (src_logical_idx + offset) % num_devices\n", | |
| " src_mesh_idx = logical_to_mesh[src_logical_idx]\n", | |
| " dst_mesh_idx = logical_to_mesh[dst_logical_idx]\n", | |
| " src_to_dst.append((src_mesh_idx, dst_mesh_idx))\n", | |
| " return jax.lax.ppermute(block, axis_name, perm=src_to_dst)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 17, | |
| "id": "46da464b-080e-4248-8993-717a9932bc24", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "[0 1 2 3 4 5 6 7]\n", | |
| "jax.typeof(y_shift_forward)=ShapedArray(int32[8@(tp,dp)])\n", | |
| "[7 0 1 2 3 4 5 6]\n", | |
| "jax.typeof(y_shift_backward)=ShapedArray(int32[8@(tp,dp)])\n", | |
| "[1 2 3 4 5 6 7 0]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.shard_map(\n", | |
| " mesh=global_mesh,\n", | |
| " in_specs=P((\"tp\", \"dp\")),\n", | |
| " out_specs=(P((\"tp\", \"dp\")), P((\"tp\", \"dp\"))),\n", | |
| ")\n", | |
| "def ring_shift_demo(\n", | |
| " x_block: Int[Array, \"n\"],\n", | |
| ") -> tuple[Int[Array, \"n\"], Int[Array, \"n\"]]:\n", | |
| " \"\"\"Demo the ring shift pattern\"\"\"\n", | |
| " # i32[8@(tp,dp)] -> i32[1]{tp=2,dp=4} varying -> i32[8@(tp,dp)]\n", | |
| " axis_name = (\"tp\", \"dp\")\n", | |
| " # ring shift forward\n", | |
| " x_block_recv_from_prev = ring_shift(x_block, axis_name, offset=1)\n", | |
| " # ring shift backward\n", | |
| " x_block_recv_from_next = ring_shift(x_block, axis_name, offset=-1)\n", | |
| " return x_block_recv_from_prev, x_block_recv_from_next\n", | |
| "\n", | |
| "\n", | |
| "y = jnp.arange(8)\n", | |
| "print(y)\n", | |
| "y_shift_forward, y_shift_backward = ring_shift_demo(y)\n", | |
| "\n", | |
| "print(f\"{jax.typeof(y_shift_forward)=}\")\n", | |
| "print(y_shift_forward)\n", | |
| "print(f\"{jax.typeof(y_shift_backward)=}\")\n", | |
| "print(y_shift_backward)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "116b4c35-b55a-409f-b427-7dbaab5de87a", | |
| "metadata": {}, | |
| "source": [ | |
| "#### Reduce-Scatter \n", | |
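| "Reduce-Scatter along a device mesh axis $i$ of size $n$ can be implemented as $n-1$ ring PPermute steps along that axis, each followed by a local add.\n", | |
| "\n", | |
| "Communication cost: $O(n(n-1))$ chunk transfers in total ($n-1$ PPermutes, each moving $n$ chunks), where a chunk is $1/n$ of a device's local block.\n", | |
| "\n", | |
| "Use cases: FSDP gradient updates; tensor-parallel forward passes and gradient updates." | |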
| ] | |
| }, | |
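The $n-1$ PPermute formulation can be sketched in plain Python, assuming the standard ring algorithm (the helper `ring_reduce_scatter` is hypothetical, and XLA's actual collective may use a different schedule). The partial sum of chunk $k$ starts on device $k+1$, travels once around the ring, and every device adds its own copy of that chunk as it passes, so after $n-1$ forward shifts device $k$ holds the full sum of chunk $k$:

```python
def ring_reduce_scatter(blocks: list[list[int]]) -> list[list[int]]:
    """Reduce-scatter over n simulated devices using n - 1 ring shifts.

    blocks[i] is device i's local block (length divisible by n).
    Returns out[i] = elementwise sum over devices of chunk i.
    """
    n = len(blocks)
    size = len(blocks[0]) // n

    def chunk(dev: int, k: int) -> list[int]:
        return blocks[dev][k * size:(k + 1) * size]

    # Device i starts with its own copy of chunk (i - 1) mod n.
    partial = [chunk(i, (i - 1) % n) for i in range(n)]
    for step in range(1, n):
        # One ppermute: every device receives the partial sum of its ring predecessor.
        received = [partial[(i - 1) % n] for i in range(n)]
        # Local add: device i now holds the partial sum of chunk (i - 1 - step) mod n.
        partial = [
            [a + b for a, b in zip(received[i], chunk(i, (i - 1 - step) % n))]
            for i in range(n)
        ]
    return partial


# Same data as the psum_scatter demo: [0, 1] on tp=0 and [2, 3] on tp=1.
print(ring_reduce_scatter([[0, 1], [2, 3]]))  # [[2], [4]]
```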
| { | |
| "cell_type": "code", | |
| "execution_count": 18, | |
| "id": "41ff0e30-f9d1-4d59-ad37-6404e6e2ece0", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "jax.typeof(x_block)=ShapedArray(int32[2]{tp})\n", | |
| "jax.typeof(y_block)=ShapedArray(int32[1]{tp})\n", | |
| "y=Array([2, 4], dtype=int32)\n", | |
| "jax.typeof(y)=ShapedArray(int32[2@tp])\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.shard_map(\n", | |
| " mesh=None, in_specs=P(\"tp\"), out_specs=P(\"tp\")\n", | |
| ") # use default mesh in context (set by set_mesh / with set_mesh(mesh)) if mesh is not specified\n", | |
| "def reduce_scatter_demo(x_block: Float[Array, \"tp_block\"]) -> Float[Array, \"tp_block\"]:\n", | |
| " # i32[4@tp=2] -> i32[2{tp=2}] (varying) -> i32[1{tp=2}] -> concatenate along dim=0 -> i32[2@tp]\n", | |
| " print(f\"{jax.typeof(x_block)=}\")\n", | |
| " y_block = jax.lax.psum_scatter(x_block, \"tp\", scatter_dimension=0, tiled=True)\n", | |
| " # tiled=True means its transpose (all-gather) concatenates the results, so the scattered dimension is kept (size 2 -> 1) rather than removed\n", | |
| " print(f\"{jax.typeof(y_block)=}\") # int32[1]{tp}\n", | |
| " return y_block\n", | |
| "\n", | |
| "\n", | |
| "y = reduce_scatter_demo(jnp.arange(4))\n", | |
| "# x = [0, 1, 2, 3] -> [0, 1] on tp=0, [2, 3] on tp=1 -> reduce scatter -> [2,4] -> int32[2@tp=2] [2] on tp=0, [4] on tp=1\n", | |
| "print(f\"{y=}\")\n", | |
| "print(f\"{jax.typeof(y)=}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "f69ff137-cc93-4404-8b33-2cf7739eda74", | |
| "metadata": {}, | |
| "source": [ | |
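| "`jax.lax.psum_scatter` and `jax.lax.all_gather` are transposes of each other. The cells below vary `scatter_dimension`, `tiled`, and `out_specs` on a 2-D input. They show how each choice determines the shape of the local and global output." | |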
| ] | |
| }, | |
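| { | |
| "cell_type": "markdown", | |
| "id": "transpose-check-sketch", | |
| "metadata": {}, | |
| "source": [ | |
| "Why are they transposes? A linear map $A$ and its transpose satisfy $\\langle A x, y \\rangle = \\langle x, A^\\top y \\rangle$. The NumPy sketch below treats the list of per-device arrays as one long vector and checks this identity numerically for a tiled reduce-scatter and a tiled all-gather. The sizes and the helpers `reduce_scatter` and `all_gather` are hypothetical single-host stand-ins, not JAX APIs.\n", | |
| "\n", | |
| "```python\n", | |
| "import numpy as np\n", | |
| "\n", | |
| "n, chunk = 2, 3  # hypothetical sizes: 2 devices on the axis, 3 elements per chunk\n", | |
| "rng = np.random.default_rng(0)\n", | |
| "\n", | |
| "\n", | |
| "def reduce_scatter(xs):\n", | |
| "    # xs[d] has n * chunk elements; device d receives chunk d of the sum (tiled)\n", | |
| "    return np.split(np.sum(xs, axis=0), n)\n", | |
| "\n", | |
| "\n", | |
| "def all_gather(ys):\n", | |
| "    # ys[d] has chunk elements; every device receives their concatenation (tiled)\n", | |
| "    full = np.concatenate(ys)\n", | |
| "    return [full.copy() for _ in range(n)]\n", | |
| "\n", | |
| "\n", | |
| "xs = [rng.standard_normal(n * chunk) for _ in range(n)]\n", | |
| "ys = [rng.standard_normal(chunk) for _ in range(n)]\n", | |
| "\n", | |
| "# Adjoint identity summed over devices: <RS(x), y> == <x, AG(y)>\n", | |
| "lhs = sum(np.dot(a, b) for a, b in zip(reduce_scatter(xs), ys))\n", | |
| "rhs = sum(np.dot(a, b) for a, b in zip(xs, all_gather(ys)))\n", | |
| "print(np.isclose(lhs, rhs))  # True\n", | |
| "```\n", | |
| "\n", | |
| "Algebraically, both sides equal $\\sum_k \\langle x_k, \\mathrm{concat}(y_0, \\dots, y_{n-1}) \\rangle$. This is why reverse-mode autodiff turns an all-gather into a reduce-scatter and vice versa." | |
| ] | |
| }, | |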
| { | |
| "cell_type": "code", | |
| "execution_count": 19, | |
| "id": "bf72425a-1688-4d57-9f1c-511758743a1e", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "jax.typeof(x_block)=ShapedArray(int32[2,2]{tp})\n", | |
| "jax.typeof(y_block)=ShapedArray(int32[1,2]{tp})\n", | |
| "y=Array([[ 2, 4],\n", | |
| " [10, 12]], dtype=int32)\n", | |
| "jax.typeof(y)=ShapedArray(int32[2@tp,2])\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.shard_map(in_specs=P(None, \"tp\"), out_specs=P(\"tp\"))\n", | |
| "def reduce_scatter_demo(\n", | |
| " x_block: Float[Array, \"batch tp_block\"],\n", | |
| ") -> Float[Array, \"batch tp_block\"]:\n", | |
| " # i32[2, 4 @ tp] -> i32[2, 2{tp=2}] (varying) -> i32[1{tp=2}, 2] -> concatenate along dim=0 -> int32[2@tp, 2]\n", | |
| " print(f\"{jax.typeof(x_block)=}\")\n", | |
| " y_block = jax.lax.psum_scatter(x_block, \"tp\", scatter_dimension=0, tiled=True)\n", | |
| " # tiled=True means its transpose (all-gather) will concatenate the results along the dim=0\n", | |
| " # to produce the reduced block, so reduce_scatter will keep that dimension\n", | |
| " print(f\"{jax.typeof(y_block)=}\") # int32[1{tp=2},2]\n", | |
| " return y_block\n", | |
| "\n", | |
| "\n", | |
| "y = reduce_scatter_demo(jnp.arange(8).reshape(2, 4))\n", | |
| "# [[0, 1] [[2, 3], [[2, 4], [[2, 4]] tp=0 int32[1{tp=2},2] Final [[2, 4]\n", | |
| "# [4, 5]] [6, 7]] [10,12]] [[10,12]] tp=1 int32[1{tp=2},2] Output [10,12]]\n", | |
| "# tp=0 + tp=1 --> reduce sum --> scatter dim=0, tiled=True -> int32[2@tp,2]\n", | |
| "print(f\"{y=}\")\n", | |
| "print(f\"{jax.typeof(y)=}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 20, | |
| "id": "fd9eddbb-4f7a-4504-929a-3a560e7a0e83", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "jax.typeof(y_block)=ShapedArray(int32[2,1]{tp})\n", | |
| "y=Array([[ 2, 4],\n", | |
| " [10, 12]], dtype=int32)\n", | |
| "jax.typeof(y)=ShapedArray(int32[2,2@tp])\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.shard_map(in_specs=P(None, \"tp\"), out_specs=P(None, \"tp\"))\n", | |
| "def reduce_scatter_demo(\n", | |
| " x_block: Float[Array, \"batch tp_block\"],\n", | |
| ") -> Float[Array, \"batch tp_block\"]:\n", | |
| " # i32[2, 4 @ tp] -> i32[2, 2{tp=2}] (varying) -> i32[2, 1{tp=2}] -> concatenate along dim=1 -> int32[2, 2@tp]\n", | |
| " y_block = jax.lax.psum_scatter(x_block, \"tp\", scatter_dimension=1, tiled=True)\n", | |
| " # tiled=True means its transpose (all-gather) will concatenate the results along the dim=1\n", | |
| " # to produce the reduced block, so reduce_scatter will keep that dimension\n", | |
| " print(f\"{jax.typeof(y_block)=}\") # int32[2,1{tp=2}]\n", | |
| " return y_block\n", | |
| "\n", | |
| "\n", | |
| "y = reduce_scatter_demo(jnp.arange(8).reshape(2, 4))\n", | |
| "# [[0, 1] [[2, 3], [[2, 4], [[2], | [[4], int32[2,1{tp}] (left on tp=0) Final [[ 2, 4],\n", | |
| "# [4, 5]] [6, 7]] [10,12]] [10]] | [12]] int32[2,1{tp}] (right on tp=1) Output [10,12]]\n", | |
| "# tp=0 + tp=1 --> reduce sum --> scatter dim=1, tiled=True ----> # int32[2,2@tp]\n", | |
| "# concatenate per out_specs along dim=1\n", | |
| "print(f\"{y=}\")\n", | |
| "print(f\"{jax.typeof(y)=}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 21, | |
| "id": "b028daf9-4e71-4ead-8b82-019927e52a68", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "jax.typeof(y_block)=ShapedArray(int32[2]{tp})\n", | |
| "y=Array([ 2, 4, 10, 12], dtype=int32)\n", | |
| "jax.typeof(y)=ShapedArray(int32[4@tp])\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.shard_map(in_specs=P(None, \"tp\"), out_specs=P(\"tp\"))\n", | |
| "def reduce_scatter_demo(\n", | |
| " x_block: Float[Array, \"batch tp_block\"],\n", | |
| ") -> Float[Array, \"(batch tp_block)\"]:\n", | |
| " # i32[2, 4 @ tp] -> i32[2, 2{tp=2}] (varying) -> int32[2]{tp=2} -> concatenate along dim=0 -> int32[4@tp]\n", | |
| " y_block = jax.lax.psum_scatter(x_block, \"tp\", scatter_dimension=0, tiled=False)\n", | |
| " # tiled=False (default) means its transpose (all-gather) will stack the results along dim=0 (add one dimension)\n", | |
| " # to produce the reduced block, so reduce_scatter will remove that dimension\n", | |
| " print(f\"{jax.typeof(y_block)=}\") # int32[2]{tp=2}\n", | |
| " return y_block\n", | |
| "\n", | |
| "\n", | |
| "y = reduce_scatter_demo(jnp.arange(8).reshape(2, 4))\n", | |
| "# [[0, 1] [[2, 3], [[2, 4], [2, 4] tp=0 int32[2]{tp} Final\n", | |
| "# [4, 5]] [6, 7]] [10,12]] [10,12] tp=1 int32[2]{tp} Output [2,4,10,12]\n", | |
| "# tp=0 + tp=1 --> reduce sum --> scatter dim=0, tiled=False -> concatenate along dim=0 -> int32[4@tp]\n", | |
| "print(f\"{y=}\")\n", | |
| "print(f\"{jax.typeof(y)=}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 22, | |
| "id": "4017a413-562f-46e4-ba29-5e0046fe250c", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "jax.typeof(y_block)=ShapedArray(int32[2,1]{tp})\n", | |
| "y=Array([[ 2],\n", | |
| " [10],\n", | |
| " [ 4],\n", | |
| " [12]], dtype=int32)\n", | |
| "jax.typeof(y)=ShapedArray(int32[4@tp,1])\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.shard_map(in_specs=P(None, \"tp\"), out_specs=P(\"tp\"))\n", | |
| "def reduce_scatter_demo(\n", | |
| " x_block: Float[Array, \"batch tp_block\"],\n", | |
| ") -> Float[Array, \"batch tp_block\"]:\n", | |
| " # i32[2, 4 @ tp] -> i32[2, 2{tp=2}] (varying) -> i32[2{tp=2},1] -> int32[4@tp,1]\n", | |
| " y_block = jax.lax.psum_scatter(x_block, \"tp\", scatter_dimension=1, tiled=True)\n", | |
| " # tiled=True means its transpose (all-gather) will concatenate the results along the dim=1\n", | |
| " # to produce the reduced block, so reduce_scatter will keep that dimension\n", | |
| " print(f\"{jax.typeof(y_block)=}\") # int32[2{tp=2},1]\n", | |
| " return y_block\n", | |
| "\n", | |
| "\n", | |
| "y = reduce_scatter_demo(jnp.arange(8).reshape(2, 4))\n", | |
| "# [[0, 1] [[2, 3], [[2, 4], [[2], | [[4], int32[2{tp},1] (left on tp=0) Final [[ 2],\n", | |
| "# [4, 5]] [6, 7]] [10,12]] [10]] | [12]] int32[2{tp},1] (right on tp=1) Output [10],\n", | |
| "# tp=0 + tp=1 --> reduce sum --> scatter dim=1, tiled=True ----> [ 4],\n", | |
| "# concatenate per out_specs along dim=0 [12]] # int32[4@tp,1]\n", | |
| "print(f\"{y=}\")\n", | |
| "print(f\"{jax.typeof(y)=}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 23, | |
| "id": "c919afd4-e908-4a3a-b749-af15b8588c09", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "jax.typeof(y_block)=ShapedArray(int32[2]{tp})\n", | |
| "y=Array([ 2, 10, 4, 12], dtype=int32)\n", | |
| "jax.typeof(y)=ShapedArray(int32[4@tp])\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.shard_map(in_specs=P(None, \"tp\"), out_specs=P(\"tp\"))\n", | |
| "def reduce_scatter_demo(\n", | |
| " x_block: Float[Array, \"batch tp_block\"],\n", | |
| ") -> Float[Array, \"batch tp_block\"]:\n", | |
| " # i32[2, 4 @ tp] -> i32[2, 2{tp=2}] (varying) -> i32[2]{tp=2} -> int32[4@tp]\n", | |
| " y_block = jax.lax.psum_scatter(x_block, \"tp\", scatter_dimension=1, tiled=False)\n", | |
| " # tiled=False means its transpose (all-gather) will stack the results along the dim=1\n", | |
| " # to produce the reduced block, so reduce_scatter will remove that dimension\n", | |
| " print(f\"{jax.typeof(y_block)=}\") # int32[2]{tp=2}\n", | |
| " return y_block\n", | |
| "\n", | |
| "\n", | |
| "y = reduce_scatter_demo(jnp.arange(8).reshape(2, 4))\n", | |
| "# [[0, 1] [[2, 3], [[2, 4], [2, 10] int32[2{tp=0}] Final [2,10, 4,12]\n", | |
| "# [4, 5]] [6, 7]] [10,12]] [4, 12] int32[2{tp=1}] Output int32[4@tp]\n", | |
| "# tp=0 + tp=1 --> reduce sum --> scatter dim=1, tiled=False ---->\n", | |
| "print(f\"{y=}\")\n", | |
| "print(f\"{jax.typeof(y)=}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "41fd2bf9-0e83-49ef-b6a9-ca6f61028d39", | |
| "metadata": {}, | |
| "source": [ | |
| "#### All-Gather\n", | |
| "All-Gather along an axis $i$ of size $n$, where each device holds a chunk of size $c$, is conceptually a Gather ($O((n-1)c)$) followed by a Broadcast ($O((n-1)nc)$).\n", | |
| "\n", | |
| "Communication cost: $O((n-1)c + (n-1)nc) = O((n^2-1)c) = O(n^2)$\n", | |
| "\n", | |
| "In practice, All-Gather is implemented as $n-1$ PPermute steps on a ring. At each step, every device forwards one chunk to its neighbor. Each device therefore sends $(n-1)c$ elements, and the ring moves $n(n-1)c$ elements in total, which is still $O(n^2)$.\n", | |
| "\n", | |
| "Use case: weight prefetching during the FSDP forward pass" | |
| ] | |
| }, | |
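| { | |
| "cell_type": "markdown", | |
| "id": "ring-all-gather-sketch", | |
| "metadata": {}, | |
| "source": [ | |
| "The ring implementation can be simulated on a single host. This minimal NumPy sketch assumes the textbook ring schedule, and `ring_all_gather` is a hypothetical helper, not a JAX API. Each device starts with one chunk. At every step, it forwards the chunk it received most recently.\n", | |
| "\n", | |
| "```python\n", | |
| "import numpy as np\n", | |
| "\n", | |
| "\n", | |
| "def ring_all_gather(blocks):\n", | |
| "    # blocks[d] is the local chunk on device d of a ring of n = len(blocks) devices.\n", | |
| "    # Returns (gathered, transfers): every device ends with all n chunks\n", | |
| "    # concatenated in device order, like all_gather(..., axis=0, tiled=True).\n", | |
| "    n = len(blocks)\n", | |
| "    held = [{d: blocks[d]} for d in range(n)]  # held[d][c]: chunk c known to device d\n", | |
| "    transfers = 0\n", | |
| "    for step in range(n - 1):\n", | |
| "        # One PPermute: device d forwards chunk (d - step) % n to device (d + 1) % n;\n", | |
| "        # at step 0 that is its own chunk, afterwards the chunk it just received.\n", | |
| "        sends = [((d + 1) % n, (d - step) % n, held[d][(d - step) % n]) for d in range(n)]\n", | |
| "        for dst, c, payload in sends:\n", | |
| "            held[dst][c] = payload\n", | |
| "            transfers += 1\n", | |
| "    gathered = [np.concatenate([held[d][c] for c in range(n)]) for d in range(n)]\n", | |
| "    return gathered, transfers\n", | |
| "\n", | |
| "\n", | |
| "# Same data as all_gather_invariant_demo below: [0, 1] on tp=0, [2, 3] on tp=1.\n", | |
| "gathered, transfers = ring_all_gather([np.array([0, 1]), np.array([2, 3])])\n", | |
| "print([g.tolist() for g in gathered], transfers)  # [[0, 1, 2, 3], [0, 1, 2, 3]] 2\n", | |
| "```\n", | |
| "\n", | |
| "Each of the $n$ devices sends $n-1$ chunks of size $c$. That gives $n(n-1)c$ elements in total, the same $O(n^2)$ order as the gather-plus-broadcast estimate." | |
| ] | |
| }, | |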
| { | |
| "cell_type": "code", | |
| "execution_count": 24, | |
| "id": "7ca3f8e3-7cbd-4349-a72e-4a63f92c6932", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "jax.typeof(w_block)=ShapedArray(int32[2]{tp})\n", | |
| "jax.typeof(w)=ShapedArray(int32[4])\n", | |
| "[0 1 2 3]\n", | |
| "int32[4]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.shard_map(in_specs=P(\"tp\"), out_specs=P())\n", | |
| "def all_gather_invariant_demo(\n", | |
| " w_block: Float[Array, \"tp_block\"],\n", | |
| ") -> Float[Array, \"(tp tp_block)\"]:\n", | |
| " # i32[4@tp] -> i32[2{tp=2}] -> concatenate along axis=0 -> i32[4] -> replicate\n", | |
| " print(f\"{jax.typeof(w_block)=}\") # int32[2]{tp} <- varying manual axis tp\n", | |
| " w = jax.lax.all_gather_invariant(w_block, \"tp\", axis=0, tiled=True)\n", | |
| " # concatenate if tiled = True, stack if tiled=False (default)\n", | |
| " print(f\"{jax.typeof(w)=}\") # int32[4] <- unvarying\n", | |
| " return w\n", | |
| "\n", | |
| "\n", | |
| "y = all_gather_invariant_demo(jnp.arange(4))\n", | |
| "# [0, 1, 2, 3] -> [0, 1] on tp=0, [2, 3] on tp=1 -> all gather -> [0, 1, 2, 3] replicated across all device mesh axes\n", | |
| "print(y)\n", | |
| "print(jax.typeof(y))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 25, | |
| "id": "760b2266-3b3d-4c56-aa19-42ff0256c3b2", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "jax.typeof(w)=ShapedArray(int32[2,2])\n", | |
| "[[0 1]\n", | |
| " [2 3]]\n", | |
| "int32[2,2]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.shard_map(in_specs=P(\"tp\"), out_specs=P())\n", | |
| "def all_gather_invariant_demo(\n", | |
| " w_block: Float[Array, \"tp_block\"],\n", | |
| ") -> Float[Array, \"tp tp_block\"]:\n", | |
| " # i32[4@tp] -> i32[2{tp=2}] -> stack along axis=0 -> i32[2,2] -> replicate\n", | |
| " w = jax.lax.all_gather_invariant(w_block, \"tp\", axis=0, tiled=False)\n", | |
| " # concatenate if tiled = True, stack if tiled=False (default)\n", | |
| " print(f\"{jax.typeof(w)=}\") # int32[2,2] <- unvarying\n", | |
| " return w\n", | |
| "\n", | |
| "\n", | |
| "y = all_gather_invariant_demo(jnp.arange(4))\n", | |
| "# [0, 1, 2, 3] -> [0, 1] on tp=0, [2, 3] on tp=1 -> all gather -> [[0, 1], [2, 3]] across tp\n", | |
| "print(y)\n", | |
| "print(jax.typeof(y))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 26, | |
| "id": "eb655670-a8c1-4cb1-9e06-d6619d31354c", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "jax.typeof(w)=ShapedArray(int32[2,2])\n", | |
| "[[0 2]\n", | |
| " [1 3]]\n", | |
| "int32[2,2]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.shard_map(in_specs=P(\"tp\"), out_specs=P())\n", | |
| "def all_gather_invariant_demo(\n", | |
| " w_block: Float[Array, \"tp_block\"],\n", | |
| ") -> Float[Array, \"tp tp_block\"]:\n", | |
| " # i32[4@tp] -> i32[2{tp=2}] -> stack along axis=1 -> i32[2,2] -> replicate\n", | |
| " w = jax.lax.all_gather_invariant(w_block, \"tp\", axis=1, tiled=False)\n", | |
| " # concatenate if tiled = True, stack if tiled=False (default)\n", | |
| " print(f\"{jax.typeof(w)=}\") # int32[2,2] <- unvarying\n", | |
| " return w\n", | |
| "\n", | |
| "\n", | |
| "y = all_gather_invariant_demo(jnp.arange(4))\n", | |
| "# [0, 1, 2, 3] -> [0, 1] on tp=0, [2, 3] on tp=1 -> all gather stack along dim=1 -> [[0, 2],\n", | |
| "# [1, 3]] replicated across all device mesh axes\n", | |
| "print(y)\n", | |
| "print(jax.typeof(y))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "efccc5f9-b41b-4329-a606-46e8d7fea17c", | |
| "metadata": {}, | |
| "source": [ | |
| "Note that `jax.lax.all_gather_invariant` is a varying-to-unvarying op, whereas `jax.lax.all_gather` is a varying-to-varying op. The difference follows from their transposes. The transpose of `jax.lax.all_gather` is a reduce-scatter (`jax.lax.psum_scatter`, to be more specific), which is varying-to-varying. The transpose of `jax.lax.all_gather_invariant` is `jax.lax.dynamic_slice`, which is unvarying-to-varying.\n", | |
| "\n", | |
| "See [varying manual mesh axis (vma)](https://docs.jax.dev/en/latest/notebooks/shard_map.html#tracking-how-values-vary-over-manual-mesh-axes-and-check-vma-true) for more information. " | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 27, | |
| "id": "6a75bd59-f1fa-4ad5-ad3c-ee93a089c581", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "jax.typeof(w_block)=ShapedArray(int32[2]{tp})\n", | |
| "jax.typeof(w)=ShapedArray(int32[4]{tp})\n", | |
| "[0 1 2 3 0 1 2 3]\n", | |
| "int32[8@tp]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.shard_map(in_specs=P(\"tp\"), out_specs=P(\"tp\"))\n", | |
| "def all_gather_demo(w_block: Float[Array, \"tp_block\"]) -> Float[Array, \"(tp tp_block)\"]:\n", | |
| " # i32[4@tp] -> i32[2{tp=2}] -> concatenate along axis=0 -> i32[4]{tp} varying -> concatenate along axis=0 -> i32[8@tp]\n", | |
| " print(f\"{jax.typeof(w_block)=}\") # int32[2]{tp} <- varying manual axis tp\n", | |
| " w = jax.lax.all_gather(w_block, \"tp\", axis=0, tiled=True)\n", | |
| " print(f\"{jax.typeof(w)=}\") # Attention: int32[4]{tp} <- varying\n", | |
| " # Note: because w is varying along tp, the output sharding spec cannot be P(), but has to be P(\"tp\")\n", | |
| " return w\n", | |
| "\n", | |
| "\n", | |
| "y = all_gather_demo(jnp.arange(4))\n", | |
| "# [0, 1, 2, 3] -> [0, 1] on tp=0, [2, 3] on tp=1 -> all gather -> [0, 1, 2, 3] varying across tp -> concatenate tp partitions along dim=0 -> i32[8@tp]\n", | |
| "print(y)\n", | |
| "print(jax.typeof(y))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "b33d8d1a-c223-438d-afa4-d99eefae1c26", | |
| "metadata": {}, | |
| "source": [ | |
| "#### All-Reduce\n", | |
| "All-Reduce along a device mesh axis $i$ of size $n$ can be implemented as a Reduce-Scatter ($O(n^2)$) followed by an All-Gather ($O(n^2)$).\n", | |
| "\n", | |
| "Communication cost: $2n(n-1)$ chunk transfers on a ring. This is still $O(n^2)$, but about twice the cost of a Reduce-Scatter or an All-Gather alone.\n", | |
| "\n", | |
| "Use case: data parallel gradient update; tensor parallel (row parallel linear) forward pass" | |
| ] | |
| }, | |
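| { | |
| "cell_type": "markdown", | |
| "id": "ring-all-reduce-sketch", | |
| "metadata": {}, | |
| "source": [ | |
| "Composing the two ring phases gives a ring all-reduce. This single-host NumPy sketch assumes the textbook ring schedule and uses a hypothetical helper `ring_all_reduce`, which is not a JAX API. `np.array_split` is used so the array length need not be a multiple of $n$; in that case some chunks are simply empty.\n", | |
| "\n", | |
| "```python\n", | |
| "import numpy as np\n", | |
| "\n", | |
| "\n", | |
| "def ring_all_reduce(xs):\n", | |
| "    # xs[d] is the local array on device d of a ring of n = len(xs) devices.\n", | |
| "    # Returns (results, transfers): every device ends with sum(xs), like psum.\n", | |
| "    n = len(xs)\n", | |
| "    chunks = [np.array_split(x, n) for x in xs]  # chunks may be empty if len(x) < n\n", | |
| "    transfers = 0\n", | |
| "    # Phase 1, reduce-scatter: afterwards device d owns the full sum of chunk d.\n", | |
| "    for step in range(n - 1):\n", | |
| "        sends = [((d + 1) % n, (d - step - 1) % n) for d in range(n)]\n", | |
| "        payloads = [chunks[d][c] for d, (_, c) in enumerate(sends)]\n", | |
| "        for (dst, c), payload in zip(sends, payloads):\n", | |
| "            chunks[dst][c] = chunks[dst][c] + payload\n", | |
| "            transfers += 1\n", | |
| "    # Phase 2, all-gather: device d forwards the reduced chunk (d - step) % n.\n", | |
| "    for step in range(n - 1):\n", | |
| "        sends = [((d + 1) % n, (d - step) % n) for d in range(n)]\n", | |
| "        payloads = [chunks[d][c] for d, (_, c) in enumerate(sends)]\n", | |
| "        for (dst, c), payload in zip(sends, payloads):\n", | |
| "            chunks[dst][c] = payload  # overwrite the stale partial sum\n", | |
| "            transfers += 1\n", | |
| "    return [np.concatenate(chunks[d]) for d in range(n)], transfers\n", | |
| "\n", | |
| "\n", | |
| "# Same shards as all_reduce_1d below: [0, 1], [2, 3], [4, 5], [6, 7] on dp=0..3.\n", | |
| "xs = [np.array([0, 1]), np.array([2, 3]), np.array([4, 5]), np.array([6, 7])]\n", | |
| "results, transfers = ring_all_reduce(xs)\n", | |
| "print([r.tolist() for r in results], transfers)\n", | |
| "# [[12, 16], [12, 16], [12, 16], [12, 16]] 24\n", | |
| "```\n", | |
| "\n", | |
| "Counting messages gives $2n(n-1) = 24$ for $n = 4$, twice the $n(n-1)$ of either phase alone." | |
| ] | |
| }, | |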
| { | |
| "cell_type": "code", | |
| "execution_count": 28, | |
| "id": "c6ae459f-78e8-48ba-9081-8949d2379426", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "jax.typeof(x_block)=ShapedArray(int32[2]{dp})\n", | |
| "jax.typeof(y_all_reduced)=ShapedArray(int32[2])\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.shard_map(in_specs=P(\"dp\"), out_specs=P())\n", | |
| "def all_reduce_1d(x_block: Float[Array, \"block_size\"]) -> Float[Array, \"block_size\"]:\n", | |
| " # i32[8@dp] -> i32[2{dp=4}] -> all-reduce -> i32[2] unvarying -> replicate\n", | |
| " print(f\"{jax.typeof(x_block)=}\") # int32[2]{dp} <- varying manual axis dp\n", | |
| " x_block = x_block + 1 # parallel execution\n", | |
| " y_all_reduced = jax.lax.psum(x_block, \"dp\")\n", | |
| " print(f\"{jax.typeof(y_all_reduced)=}\") # int32[2] <- unvarying manual axis dp, tp\n", | |
| " return y_all_reduced\n", | |
| "\n", | |
| "\n", | |
| "y = all_reduce_1d(jnp.arange(8) - 1)\n", | |
| "assert jnp.all(y == jnp.sum(jnp.asarray([[0, 1], [2, 3], [4, 5], [6, 7]]), axis=0))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 29, | |
| "id": "58309258-4142-4599-8d43-16b10499a87f", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "jax.typeof(x_block)=ShapedArray(float32[256,4096]{dp})\n", | |
| "jax.typeof(w_block)=ShapedArray(float32[4096,128]{tp})\n", | |
| "jax.typeof(y_block)=ShapedArray(float32[256,128]{tp,dp})\n", | |
| "jax.typeof(y_block_all_reduced)=ShapedArray(float32[256,128])\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.shard_map(in_specs=(P(\"dp\"), P(None, \"tp\")), out_specs=P())\n", | |
| "def all_reduce_2d(\n", | |
| " x_block: Float[Array, \"dp_block hidden\"], w_block: Float[Array, \"hidden tp_block\"]\n", | |
| ") -> Float[Array, \"dp_block tp_block\"]:\n", | |
| " # f32[1024@dp=4, 4096] @ f32[4096, 256@tp=2]\n", | |
| " # f32[256{dp=4}, 4096] @ f32[4096, 128{tp=2}] -> f32[256{dp=4}, 128{tp=2}]\n", | |
| " print(f\"{jax.typeof(x_block)=}\") # float32[256,4096]{dp=4, None}\n", | |
| " print(f\"{jax.typeof(w_block)=}\") # float32[4096,128]{None, tp=2}\n", | |
| " y_block = x_block @ w_block\n", | |
| " # float32[256,128]{tp,dp} <- varying manual axis along dp and tp\n", | |
| " print(f\"{jax.typeof(y_block)=}\")\n", | |
| " y_block_all_reduced = jax.lax.psum(y_block, (\"dp\", \"tp\"))\n", | |
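| " # float32[256,128] <- unvarying along both manual axes dp and tp\n", | |
| " # Note: summing over dp and tp only demonstrates a multi-axis psum; the result is not x @ w\n", | |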
| " print(f\"{jax.typeof(y_block_all_reduced)=}\")\n", | |
| " return y_block_all_reduced\n", | |
| "\n", | |
| "\n", | |
| "y = all_reduce_2d(x, w)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "e93a7b61-7f84-432a-af0a-749bd6d7e77f", | |
| "metadata": {}, | |
| "source": [ | |
| "#### All-to-All\n", | |
| "Use case: MoE expert routing\n", | |
| "\n", | |
| "Think of `all_to_all` as a **block transpose between a local axis and a mesh axis**. Every device sends a slice to every other device, and receives a slice from every other device. \n", | |
| "\n", | |
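| "`all_to_all(x, axis_name, split_axis, concat_axis, tiled)`: let `n = jax.lax.axis_size(axis_name)` be the number of devices along `axis_name`. Each device splits its local block `x` along `split_axis` into `n` chunks. With `tiled=True`, the size of `split_axis` must be divisible by `n`. With `tiled=False`, it must equal `n`, and that axis is removed. The device sends chunk `i` to the device with index `i` along `axis_name` and receives one chunk from every device. It then concatenates (`tiled=True`) or stacks (`tiled=False`) the received chunks along `concat_axis` in sender order." | |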
| ] | |
| }, | |
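| { | |
| "cell_type": "markdown", | |
| "id": "all-to-all-sketch", | |
| "metadata": {}, | |
| "source": [ | |
| "This is a single-host NumPy sketch of the `tiled=True` semantics described above. It assumes a Python list stands in for the devices along `axis_name`. `all_to_all_tiled` is a hypothetical helper, not the JAX API. With one row of `x` per device and `split_axis=concat_axis=0`, device `i` ends up with column `i` of `x`. This is the block transpose shown in the next cell.\n", | |
| "\n", | |
| "```python\n", | |
| "import numpy as np\n", | |
| "\n", | |
| "\n", | |
| "def all_to_all_tiled(blocks, split_axis=0, concat_axis=0):\n", | |
| "    # blocks[d] is the local block on device d; n = len(blocks) devices on the axis.\n", | |
| "    # Mimics tiled=True: split each block into n chunks along split_axis, send\n", | |
| "    # chunk i to device i, and concatenate what device i receives (in sender\n", | |
| "    # order) along concat_axis.\n", | |
| "    n = len(blocks)\n", | |
| "    pieces = [np.split(b, n, axis=split_axis) for b in blocks]  # pieces[src][dst]\n", | |
| "    return [\n", | |
| "        np.concatenate([pieces[src][dst] for src in range(n)], axis=concat_axis)\n", | |
| "        for dst in range(n)\n", | |
| "    ]\n", | |
| "\n", | |
| "\n", | |
| "# Same data as all_to_all_demo below: device d holds row d of x.\n", | |
| "x = 10 * np.arange(8)[:, None] + np.arange(8)[None, :]\n", | |
| "out = all_to_all_tiled(list(x))\n", | |
| "print(np.array_equal(np.stack(out), x.T))  # True\n", | |
| "```\n", | |
| "\n", | |
| "Splitting and concatenating along different axes, for example `split_axis=1, concat_axis=0`, moves data from one local dimension to another. This is how MoE layers regroup tokens by expert." | |
| ] | |
| }, | |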
| { | |
| "cell_type": "code", | |
| "execution_count": 30, | |
| "id": "466d8ff2-4657-46c9-985c-b339df1d3470", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "x=Array([[ 0, 1, 2, 3, 4, 5, 6, 7],\n", | |
| " [10, 11, 12, 13, 14, 15, 16, 17],\n", | |
| " [20, 21, 22, 23, 24, 25, 26, 27],\n", | |
| " [30, 31, 32, 33, 34, 35, 36, 37],\n", | |
| " [40, 41, 42, 43, 44, 45, 46, 47],\n", | |
| " [50, 51, 52, 53, 54, 55, 56, 57],\n", | |
| " [60, 61, 62, 63, 64, 65, 66, 67],\n", | |
| " [70, 71, 72, 73, 74, 75, 76, 77]], dtype=int32)\n", | |
| "jax.typeof(x_block)=ShapedArray(int32[8]{tp,dp})\n", | |
| "jax.typeof(y_block)=ShapedArray(int32[8]{tp,dp})\n", | |
| "x_all_to_all=Array([[ 0, 10, 20, 30, 40, 50, 60, 70],\n", | |
| " [ 1, 11, 21, 31, 41, 51, 61, 71],\n", | |
| " [ 2, 12, 22, 32, 42, 52, 62, 72],\n", | |
| " [ 3, 13, 23, 33, 43, 53, 63, 73],\n", | |
| " [ 4, 14, 24, 34, 44, 54, 64, 74],\n", | |
| " [ 5, 15, 25, 35, 45, 55, 65, 75],\n", | |
| " [ 6, 16, 26, 36, 46, 56, 66, 76],\n", | |
| " [ 7, 17, 27, 37, 47, 57, 67, 77]], dtype=int32)\n", | |
| "jax.typeof(x_all_to_all)=ShapedArray(int32[8@(dp,tp),8])\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.shard_map(\n", | |
| " in_specs=P((\"dp\", \"tp\")),\n", | |
| " out_specs=P((\"dp\", \"tp\")),\n", | |
| ")\n", | |
| "def all_to_all_demo(x_block: Array) -> Array:\n", | |
| " # i32[64@(dp,tp)] -> i32[8{dp,tp}] -> split along axis=0 to dp x tp partitions -> all-to-all send\n", | |
| " # each local split i32[1{dp,tp}] to the corresponding device on mesh axis (\"dp\", \"tp\")\n", | |
| " # -> each (dp,tp) device receives dp x tp partitions -> concatenate along axis=0 -> i32[8{dp,tp}] -> i32[64@(dp,tp)]\n", | |
| " print(f\"{jax.typeof(x_block)=}\")\n", | |
| " y_block = jax.lax.all_to_all(\n", | |
| " x_block, axis_name=(\"dp\", \"tp\"), split_axis=0, concat_axis=0, tiled=True\n", | |
| " )\n", | |
| " print(f\"{jax.typeof(y_block)=}\")\n", | |
| " return y_block\n", | |
| "\n", | |
| "\n", | |
| "x = 10 * jnp.arange(8)[:, None] + jnp.arange(8)[None, :]\n", | |
| "print(f\"{x=}\")\n", | |
| "x_all_to_all: Array = all_to_all_demo(x.ravel()).reshape(x.shape)\n", | |
| "print(f\"{x_all_to_all=}\")\n", | |
| "print(f\"{jax.typeof(x_all_to_all)=}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 31, | |
| "id": "a57c6abf-f855-4efd-8c0c-0bae94ec574e", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "x=Array([[ 0, 1, 2, 3, 4, 5, 6, 7],\n", | |
| " [10, 11, 12, 13, 14, 15, 16, 17],\n", | |
| " [20, 21, 22, 23, 24, 25, 26, 27],\n", | |
| " [30, 31, 32, 33, 34, 35, 36, 37],\n", | |
| " [40, 41, 42, 43, 44, 45, 46, 47],\n", | |
| " [50, 51, 52, 53, 54, 55, 56, 57],\n", | |
| " [60, 61, 62, 63, 64, 65, 66, 67],\n", | |
| " [70, 71, 72, 73, 74, 75, 76, 77]], dtype=int32)\n", | |
| "jax.typeof(x_block)=ShapedArray(int32[2,4]{tp,dp})\n", | |
| "jax.typeof(y_block)=ShapedArray(int32[4,2]{tp,dp})\n", | |
| "x_all_to_all=Array([[ 0, 10, 20, 30, 40, 50, 60, 70],\n", | |
| " [ 1, 11, 21, 31, 41, 51, 61, 71],\n", | |
| " [ 2, 12, 22, 32, 42, 52, 62, 72],\n", | |
| " [ 3, 13, 23, 33, 43, 53, 63, 73],\n", | |
| " [ 4, 14, 24, 34, 44, 54, 64, 74],\n", | |
| " [ 5, 15, 25, 35, 45, 55, 65, 75],\n", | |
| " [ 6, 16, 26, 36, 46, 56, 66, 76],\n", | |
| " [ 7, 17, 27, 37, 47, 57, 67, 77]], dtype=int32)\n", | |
| "jax.typeof(x_all_to_all)=ShapedArray(int32[8@(tp,dp),8])\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.shard_map(\n", | |
| " in_specs=P(\"dp\", \"tp\"),\n", | |
| " out_specs=P(\n", | |
| " (\"tp\", \"dp\")\n", | |
| " ), # exercise question: why P((\"tp\", \"dp\"))? What if P((\"dp\", \"tp\"))?\n", | |
| ")\n", | |
| "def all_to_all_demo(x_block: Array) -> Array:\n", | |
| " # i32[8@dp, 8@tp] -> local block i32[2, 4] on device (dp=d, tp=t)\n", | |
| " # -> tiled=False: unstack split_axis=1 (size 4 == size of dp) into 4 chunks of i32[2]\n", | |
| " # -> send chunk j to device (dp=j, tp=t); each device receives one i32[2] chunk from every dp index\n", | |
| " # -> stack the received chunks along concat_axis=0 -> i32[4, 2] -> reshape to i32[1, 8]\n", | |
| " # -> out_specs concatenates the local blocks along axis=0 -> i32[8@(tp,dp), 8]\n", | |
| " print(f\"{jax.typeof(x_block)=}\")\n", | |
| " y_block = jax.lax.all_to_all(\n", | |
| " x_block, axis_name=\"dp\", split_axis=1, concat_axis=0, tiled=False\n", | |
| " )\n", | |
| " print(f\"{jax.typeof(y_block)=}\")\n", | |
| " return y_block.reshape(1, -1)\n", | |
| "\n", | |
| "\n", | |
| "x = 10 * jnp.arange(8)[:, None] + jnp.arange(8)[None, :]\n", | |
| "print(f\"{x=}\")\n", | |
| "x_all_to_all: Array = all_to_all_demo(x)\n", | |
| "print(f\"{x_all_to_all=}\")\n", | |
| "print(f\"{jax.typeof(x_all_to_all)=}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "1a6afc9a-0517-4e88-80c8-03ab626e9f97", | |
| "metadata": {}, | |
| "source": [ | |
| "## Matrix Multiplication " | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "e6503226-47ed-4a1a-b54c-090c906769b5", | |
| "metadata": {}, | |
| "source": [ | |
| "### Naive Version" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 32, | |
| "id": "82da5ba3-fb46-43d9-a868-f1780d48e319", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "13.6 ms ± 145 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "x = jax.random.normal(jax.random.key(0), (1024, 4096))\n", | |
| "w = jax.random.normal(jax.random.key(1), (4096, 256))\n", | |
| "\n", | |
| "\n", | |
| "@jax.jit\n", | |
| "def matmul(\n", | |
| " x: Float[Array, \"batch n_in\"], w: Float[Array, \"n_in n_out\"]\n", | |
| ") -> Float[Array, \"batch n_out\"]:\n", | |
| " \"\"\"Naive matrix multiplication\"\"\"\n", | |
| " return x @ w\n", | |
| "\n", | |
| "\n", | |
| "y = matmul(x, w)\n", | |
| "\n", | |
| "%timeit y = matmul(x, w).block_until_ready()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "e2e48a80-b324-4015-b7aa-024c14999a53", | |
| "metadata": {}, | |
| "source": [ | |
| "### FSDP Forward\n" | |
| ] | |
| }, | |
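| { | |
| "cell_type": "markdown", | |
| "id": "fsdp-forward-sketch", | |
| "metadata": {}, | |
| "source": [ | |
| "This single-host NumPy sketch traces the data flow in `matmul_fsdp` below. The sizes are made up, and a Python loop stands in for 4 devices on `dp`. It only checks the math, not performance or memory.\n", | |
| "\n", | |
| "```python\n", | |
| "import numpy as np\n", | |
| "\n", | |
| "n = 4  # hypothetical: 4 devices on the dp (FSDP) axis\n", | |
| "rng = np.random.default_rng(0)\n", | |
| "x = rng.standard_normal((16, 8))  # [batch, n_in], made-up sizes\n", | |
| "w = rng.standard_normal((8, 3))  # [n_in, n_out]\n", | |
| "\n", | |
| "# in_specs=(P(\"dp\", None), P(\"dp\", None)): device d keeps one batch shard of x\n", | |
| "# and one n_in shard of w\n", | |
| "x_blocks = np.split(x, n, axis=0)\n", | |
| "w_blocks = np.split(w, n, axis=0)\n", | |
| "\n", | |
| "y_blocks = []\n", | |
| "for d in range(n):\n", | |
| "    # all_gather(w_block, \"dp\", axis=0, tiled=True): rebuild the full weight on device d\n", | |
| "    w_full = np.concatenate(w_blocks, axis=0)\n", | |
| "    y_blocks.append(x_blocks[d] @ w_full)  # [batch/n, n_in] @ [n_in, n_out]\n", | |
| "\n", | |
| "# out_specs=P(\"dp\", None): the local outputs are concatenated along the batch axis\n", | |
| "y = np.concatenate(y_blocks, axis=0)\n", | |
| "print(np.allclose(y, x @ w))  # True\n", | |
| "```\n", | |
| "\n", | |
| "Each device stores only $1/n$ of `w` at rest. It must still materialize the whole weight before its local matmul, so the all-gather sits on the critical path of the forward pass." | |
| ] | |
| }, | |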
| { | |
| "cell_type": "code", | |
| "execution_count": 33, | |
| "id": "e093d979-1354-4f4c-a4e3-c53cf8677342", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "6.59 ms ± 257 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.jit\n", | |
| "@jax.shard_map(\n", | |
| " mesh=global_mesh, in_specs=(P(\"dp\", None), P(\"dp\", None)), out_specs=P(\"dp\", None)\n", | |
| ")\n", | |
| "def matmul_fsdp(\n", | |
| " x_block: Float[Array, \"batch_local n_in\"],\n", | |
| " w_block: Float[Array, \"n_in_local n_out\"],\n", | |
| ") -> Float[Array, \"batch_local n_out\"]:\n", | |
| " \"\"\"FSDP matrix multiplication\n", | |
| "\n", | |
| " Data and weights are sharded across the dp/fsdp device mesh axis.\n", | |
| " The weights are all-gathered during the forward pass. Advanced implementations\n", | |
| " use prefetching to overlap computation (in the current layer) with\n", | |
| " communication (all-gathering the weights of the next layer).\n", | |
| " \"\"\"\n", | |
| " # f32[1024@dp, 4096] @ f32[4096@dp, 256]\n", | |
| " # local blocks: f32[256{dp=4}, 4096] and f32[1024{dp=4}, 256]\n", | |
| "\n", | |
| " w_full = jax.lax.all_gather(w_block, \"dp\", axis=0, tiled=True) # f32[4x1024, 256]\n", | |
| " y_block = x_block @ w_full\n", | |
| " # f32[256{dp=4}, 4096] @ f32[4096, 256] -> f32[256{dp=4}, 256]\n", | |
| " return y_block\n", | |
| "\n", | |
| "\n", | |
| "assert jnp.allclose(matmul_fsdp(x, w), y) # verify correctness\n", | |
| "%timeit matmul_fsdp(x, w).block_until_ready()  # FSDP is faster than the non-parallel version here" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "30fd9e0a-c390-47a2-ac17-ce39a098e10a", | |
| "metadata": {}, | |
| "source": [ | |
| "### Profiling matmul_fsdp" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 34, | |
| "id": "b9f012a6-0395-4a13-9dcd-511e10596e5b", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "E0201 17:26:43.205688 741064 python_hooks.cc:416] Can't import tensorflow.python.profiler.trace\n", | |
| "E0201 17:26:43.231453 741064 python_hooks.cc:416] Can't import tensorflow.python.profiler.trace\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "jax.profiler.start_trace(log_dir=\"./jax_profiler\")\n", | |
| "_ = matmul_fsdp(x, w).block_until_ready()\n", | |
| "# ^ call block_until_ready() inside the trace so the asynchronously dispatched computation is captured\n", | |
| "jax.profiler.stop_trace()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 36, | |
| "id": "b3e9318d-e0a0-4c57-ac7e-91eba7007e47", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# %pip install tensorboard-plugin-profile # <- uncomment to install profile plugin if it does not render\n", | |
| "%load_ext tensorboard" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 37, | |
| "id": "f35e62e9-1d16-4b9b-a41e-fa3fd5ebc4d3", | |
| "metadata": { | |
| "collapsed": true, | |
| "jupyter": { | |
| "outputs_hidden": true | |
| }, | |
| "scrolled": true | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "\n", | |
| " <iframe id=\"tensorboard-frame-66630db69001805b\" width=\"100%\" height=\"800\" frameborder=\"0\">\n", | |
| " </iframe>\n", | |
| " <script>\n", | |
| " (function() {\n", | |
| " const frame = document.getElementById(\"tensorboard-frame-66630db69001805b\");\n", | |
| " const url = new URL(\"/\", window.location);\n", | |
| " const port = 6006;\n", | |
| " if (port) {\n", | |
| " url.port = port;\n", | |
| " }\n", | |
| " frame.src = url;\n", | |
| " })();\n", | |
| " </script>\n", | |
| " " | |
| ], | |
| "text/plain": [ | |
| "<IPython.core.display.HTML object>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "%tensorboard --logdir ./jax_profiler" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "4d8320de-7aad-42d9-86d7-768611b04039", | |
| "metadata": {}, | |
| "source": [ | |
| "As the trace viewer shows, the matmul cannot start until the all-gather of the weights has finished. Can we do better?" | |
| ] | |
| }, | |
| { | |
| "attachments": { | |
| "2da00048-5b90-4985-afe0-96c38199a47d.png": { | |
| "image/png": "iVBORw0KGgoAAAANSUhEUgAABXcAAAVUCAYAAACbbiv3AAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAAAJcEhZcwAADsMAAA7DAcdvqGQAAP+lSURBVHhe7N15XFT1/sfx9zCshqaUJZYlZiZuud0SqhvKDcutUqsrZIuaLWq3wjbT0jLbpMUly9QWE7sltrh0pStiC2gpapp4+ZVomlAUkpLszO+PGYaZw7Aoohx9PR+P82jm+z3nzJkzB8L3fM/na/lrsmwCAAAAAAAAAJiKl7EBAAAAAAAAAND4Ee4CAAAAAAAAgAkR7gIAAAAAAACACRHuAgAAAAAAAIAJEe4CAAAAAAAAgAkR7gIAAAAAAACACRHuAgAAAAAAAIAJEe4CAAAAAAAAgAkR7gIAAAAAAACACRHuAgAAAAAAAIAJEe4CAAAAAAAAgAkR7gIAAAAAAACACRHuAgAAAAAAAIAJEe4CAAAAAAAAgAkR7gIAAAAAAACACRHuAgAAAAAAAIAJEe4CAAAAAAAAgAkR7gIAAAAAAACACVn+miybsREA6sr3fpu8W7q32TInqmBRnHsjAAAAAAAAjiuThLs9ZO0eKoux+XgqzVbZjiQznIzT3DFeC4XpKt21xdiK44BwFwAAAAAA4OQwSbi7QP7PjG7YGhKFSSp5NlIlxnY0MvW8Fo5kqnxPvEq/nKzSX4ydOBaEuwAAAAAAACfHMWdkgCk1CZFXpyfkO/ag/IeNO/oRwAAAAAAAAEAjQbiL05NXc3l1nyP/UU/wQwAAAAAAAABTItfCac0SMlm+Nw4wNgMwCa+O0fK+ZoF8hy2R7zXj5N2xh3EVAAAAAABOWSapuVtHN+5Uk56hhsZ0lU7ppGJDK8zKc81dzzVeQ2TtEiaLd4i8ug6TtV0PWbwNq0hScapKngmn3vIxoubuqSBWfk/MlNXf2O5BeaFs+Zkqz1io0q/iVJZrXMHOZ9RB+YQ0NzZXVV4o26F0lf+4UCWfzlW5sd+ToGj59J8s7w6hnn+my/NU/uNCla6aqNJqjk+q7v8ZdZCzUEdmjXE+df8ZyFPZ5y1UlOLs9nB+j+3/S3U+p0bUlAcAAACAo2K9ZKC8zu0mS8tQqSBXZfs2qHzv17Id2m9c9aQj3IXJHE24axA0Tn6j58jazNhRqPLkABWuNbajLgh3TwXG8LGOSjNV9nmkir7NNPYcWxBZmK7SLwaq2MP+Kli6zZHf4HHyqsux1nB8UnX/z6gDwl0AAAAApxm/O/8ra7tI5/PSzQtV/Enlv4skybv7SPkMmiuLX1N7Q1mJSjfMVvF/Yt3WM/KJfFo+Vz0qWX0lSWW716ro7X8YV/PIq00f+d38gSzNL7Q3FP+l4sTHVLpxjnHVWnm16SOf8Ifk1T5KFv8zjd1SeanKc9JVmvqaSjcvNPZWOUfHwlZ0WCUrx6l062LJw7kp/22nCmd3dq7vP+bLKhkZcOrKnauitas9fJvhL0ubccZGALXxDpF1cJr8rgwx9hwb/1B5X7dKvh2NHQ4dF8jvRg/BbnmhVFyoKsN+vUNkHZhy/I4PAAAAAFB3Vh9Zu/1T1g7XGXuOG2tIX1kCz61s8D1D1ouvdV2lTrwvu1d+MZ/K2uUmz8GuJHl5y+vcrvId/Lp8b3jL2HtSlP2cwsjduvDq+IS8L+0nr5bBLq15su1LUtn3C1WaWc2oMAdLyFBZz3RPI2w58Sr7peJZD3lfM07WC8Nlca5WKFtOisq3LVTJri3O7erkvAHy6T5aXm0MtywXZqo8K0VlW591ee2j4/lcHO2x9pC1e6gsbm15Kt+62p7NnDdaPn2iZQ12vEZppsq+GKiSn1S/kbuSpOnyn/KEvOxfeDjVffsQWXuPljW0n7zOdB1BV/froXqOfV8UZji/kv5MUVl6vEo3JdXyA1tRisK1rVC2n5fbb58P6iefsFhZQyrCrjyVp4SrOM11/Qohsl4ZK+8O4bI0qbgw
HZ/1t3EqcbxPRu6eCuo2stQSMlTe3SbKu7vhGvMwMtQ4yrQ8zaLCj11WUIisXQbIq+s4eXcMlcX1hzprrgpeH2+41mPl9/hMWZu4NBWmq3T9eJV8XfFzESLrlU/I54rR8gp0XS9JJfMiVWIs0WD4f8axXrcna+Ru1XMKAAAAAA3DOCq1TiN3JclmU9lPX6jo3f6uq7oxjk49mpG7/mO+lNeFV7m12Q79ouKE21S2O8mtvTref7tHPlHPyeJvuFOypEC28lJZLF6STxPJ4pJklZepdOu7Kv54tLPJeI6OxdGO3PVq04dwtyZefZbIt2+0vFzDBE/yklTyxRiVfO851DP+g1zOf5SHyHtYgny6VFML1sH2a7xK4mNqrh0pydJtpnwjRsvasvbbdm2/LlfplxOrPWZ3IfK+bo68uw+o/VwUZ6ps02QVfx5fw4XlKaC1f05l/VPkGx7mHvS4hSWetj2aUMYYstjVvn0PeQ9bKJ9OPWQxBMNVHElX6boYFW+oS9Atx76XyKebIeDypDhT5VvjVLRibjXn19P7s5+/4tIE+V83tMq15ikg8uqzRL6R0VVHSLqw/Rqvko9j5DWMcNf8jNdNLb83z5sp/1Gxbl+SlG9op8JVlb9PjL/3PF1nFSxXrpV//34uX/ikq/SVTip2+Z1nvWm3/Lq5jMA9kqSSNz0EtpIUFCu/u92DYNuuMSpYYrhthnAXAAAAAOrEGFzWOdyVPSQtSX5aJV8+797uYAww6xruWjsOlu8NC2Q54xz3jrISlaS8rJLEx9zbPfA6p5P8oj+R5ayLnW22Q/tVsv45lX77urPNGhIhn8in5XXBlc6Q11aQq+IV41S2/QPnetXxn/CDvM7pZH9SVqySr15QydonjatVYTw3xnBXUpWMDJKkAfIZkyX/gXUIdiWpeT/53JQm/4EDjD3VsrSYKd/70uTbveZgV5Is50bL9+618gky9lSyDkyT/02xdQp2Jcly7lD5DEuRf/9ajjko2n6c4XUIdiXJN0TW8CUKeDBB3jUcb1X+sty4U35XGoPd462F5/2XFxpbKp33hPweTrF/VrUFu5LUJFTeA9MUMGamrMY+o45z5P9Emny71yHYlf38el02RwHjF9S+b1fBq+Q/sGqw64n1xp32a7+GYFcV1+WYnfLyM/bglPfLRJVkuH8x5HVe5beVR8v29bMqz3NtCZbF8f88u9HybudaWqFQ5d+O8RzsSlJunIrWuJdgsbQbLR+X5wAAAACAE8QnQN69x8qrTR9jT71YOwyQpcnZ9ic2m2Rz1Oqz+sh60TVu61bHu88EWYIucj63Hdqv4k/HugW7klSWmayixQNVfmCzs80SECTvbiPc1jsZ6hInnWYGyPe+JfK5sJWxoxbN5dUnQf7X9zN2eGRpM07ewXULYiVJTfrJ+wbPBagt16XJr08PQ5mDOvBqJa8rl1RfjzJonPxuX3h0x1khaKh8R6+qMZB2FyJrlVHXx58lsl+VkgxSoWx75hob7TrOlP9t02VtVkvS6YHlwlj53jen+hC24wL53+KhfmhdnDtafvfNrONn3lzW7gPqFB5bIlPkezSfg3eovKpMUIfTQdm+TPfR4/5H+zvTVZLKD7qmu81lcR0N3n2oe5mF/CSVrq3lroO0OHspkgq+ofIKd3kOAAAAADhhLM3byif8QWPzMbP4NbWXY3CEHbbDB1R+oLLmpKVFiLy7RbtsUZXl7A7yCunr3IfKSlS2/UOVZXxuXFVylEwo27lcKq0cIOjVqpu8zvub23onWh3intOL9ZY51YaZtrx0lf9qX2xHjL2S5C+v7nPkWxn4V8/bJdErz5Mtx7HvvLyqkwI5WC6Mka8xLA2aKb/LehgaZQ8s9y5X6RcxKk4Yr5JNq1We72l0anNZ+y7xEMKGyOeG6bIGVZM8Hqo8F+WHPO1XUrMB8hk+vY4BZMOzdFsivyvDjM32oCjZ2Ch70H9dbLUjlmu/HiRL8Dj53uJpdPQA+Q0c
LS9PI2mPpKtsR7xKv5is0q3xKvsl27iGXfA4+UUYG+shaKZ95LSxvcKRTPv7zcmWrdTYCdSTW22jQulQ5TPLRe41um05Sar9EkxS+W+uPzvN5dWmbl++AQAAAACOg5IjUkmB/bHFIq/2/eV92b3GtY6J9dKR8mp+ofN5+e+7VLZ7rVRmnw3G4t9c1ksGuWxRlbVNuCxNK+c8sv31m8oyVrmtY1T+yybZCg5WNvi3kJdLSYeTgXDXVccE+XbxMIr1SKpKPmqngrhOKpxjXwqea6filNSqdU+9Q+Xd9wljazXyVL51vAqfaqGCWY59x7XQkddiVJrjITD16iGvy92bLFcPqBoQlmeq9KNOKlgwTMVfxqt061yVfDpQhS90UvEOD6PdfMNkvdrwvsMXyNtQJ1iSVLhFpR+105GXKs9F4UsBKvhorufKBm1i5Rfp4ZzWpDhb5ZnLVbo13rEkqLya8NTJO1je3aM9L3+fI797shRwU7SHc5Wtsi8GegyKrDfO9Fxa4tBqlbxhMVwPFhWuXe0x9LR0mSO/jobGiMmyeji9tsyJKniuk4r+HaPiL59VcUKMit4IVsEaTxOp+cur00xjY83K82Tbu9rl3C5X+a/2Lu+BMVXPjyQdSlLJR+105Ll29vc7K1gF0xzXfzVfRODUZ20T4v5FQGE1X0LURdB0Wc9ybciSbU/lM+8zW7h2ynawbnVxS/Nd/ocrSWd4+iIMAAAAANAQbAUH7ZOa2eyJhsX/THn/7W5Zzu5gXPWoWS/uL/k6bvEsLVT5T2tVvme9bAWOWzgtFlla96zxtSzndpHFp3JEny3/11onYSvfv0HFqyaoOOE2+/LZPSrb86VxtROKcNeFz5UDqo5aLE1X6YfhHiYey1Tp5+Eq3mpsl9RmWB1qOxaqPC1GhQlzqw7UzY1X8fsLVV6lQ7IEu5dm8G5R+Q2D0y/xKq5yvLIf87+fVVm+sV3yaulaL7OffHu7Tm7kUJqu0oSeHvdt+368ChMWqrxKuOkvr27Tqy9N4KZQtp3jVfhMsAoXDVNxQoxjGaOSrcZ13VnaxMp32BLPyzXjZD3Pwy3jpZkq+yJcRZWj9l3EyruTh/IER5JUsnCgSn4xdkjlyQNVuNZTCBsi65Xugb93Ww/7LkxS6aI4D9tLtq8jVbrXQ3reokcdrjU72744Fb/WQgULBrqc22EqTpGkJ+Td1sM5KkxSycLIaq//wi88vV+c8joukE9H9y9tyn8xTFZWV+eNk9/t7pOzKTdVpT+5PHcb1ZsnW47L05rkZLnX3Q308HPnwhLYr+rvD+NyZeMZ/evVrUBNptS8BMSMM24GAAAAACdM6bb3ZcurHL3jdU4X+fS5322do+XVpo+8Wl3qfG7L/1VlmetUlvG5yn/bUble8wvl3WmY87mRxb+5XEe52QoNA4Q8sBUdVtkPCSrdulilWxerbPsHsh3ab1zthCLcdZopa5uqJQhsO8er2DVkMChLWKhy45TnXqGyVk4g6FnuchV9vNrYWik3zjDBkJ2liXs4YWlSdfinrbSmi3Ghyg9kS8WF7ouXy34uGiera71LB9vO8SreZWx1sWuMSr5PN7ZKQf3k3d3YWJUtc7IKl3oIu4+7Qtn2zlXx7HYq+toYWjpEDnOZ2b5CocpTI6ufxKkihM308MEZAn/bj0tcRs86lm/jZb95wLOSXzycW//gqiG8J7nxKp4/UaXVHXvkQA+1iKXyrTVMWiXJ9vUY97qmOEX4y2IcAd89Wt7XLJDfqJ0KGGEoKVKYpLJV1fwsOXi13yn/8YYl9qCajJ3jXv6lPFtl62NU5rJtlYkAG+qXRMsBVd+zYbF2aESjf739Jd9aFr8qv8gAAAAA4MT56zeVbn2vsk6tl1XWLjfJ2ulG45p1Zg29QZamrZ3Py7O3qXzfBvvjn1OkMkdQ5+0vrwuucK5nZGnR1u257c+TG9IeK8LdCtf1k1eVs5GusrU1D8eWnlXZH8Y2f1la1Vy0WWUF
tYx4zJTtTw8hoeEYy/OrrmM5L0a+xjIALkoWB+vIMwHuyxvjnf2W7h4mZyvfUodzIZV9vNrDiONW8rqkttFueSrf5XnU6vHnL8uF4+R7V5p8L/NcMsLnQg8j/KqtzeuuJC2l6vswBP5lX493GT3rWL6oZeTjnwer7reOyve4h2VG3m08vN/yLbUGdlKmatwxTCpEVuOI1WFL5Pv30bKGhLpPzlearbL1Y2r8YkKS1CxUXucalubN3X+n1TiaHgAAAABgRqXfxKns5xTnc8sZ58j78vGy+DV1W6+urBeES1bHELqSApX//I2zryxjtWyHs5zPvc7rLWuH65zPa1T1dnRTqBJnnq68z/UQ8uVnqrwuoxILPQSsZ/Y0NjWI0j0eRnP69pD3iIPyH7VAPh2PfpSZ91keSj3kpdc4grPSRJV7uGXacpanScWOo3LDSGTXpUrY7BDYQ96D0+Q/0Hhs/eTVwsOI6DpN4iRp62oPNYLrEPi78OpYdcSgd9MA42rHjVdg1fervPS6vV+cvvKSVPJxePUj4I+C7Ze5Kl5QzWh6489wQ/2fq2LSwBoW2x+VfyScbLYdhi+IPC3rlhs3AwAAAIATylZ0WKXfzJTt8AFnm/WCcHlf4V56tC68u0XLck4X53Pb4V9Umv6J83n5vg0q/7WyNIMlIEheba92Pj8VNdQ/kU3HY7jVpJ98PdQwNC4+F3rY1ttDW0NInqzSXz3UYvVqLq+Q0fKJSVOTaQUKiE2T3y1z6hT2Wvyr3sZr+zOtzqNGbUeqht3yq2P5gGNk2zu56mjkiuUpi468MVAlW9M9TADWXF59FsrPLYvvIVU9BdKfdQ1J5sr2l7HNfg6q49VzuvzG7FbAlAI1ecYm/xgPoyavDGugc9hPFj9j29F95jgFGb8kcSy2Q+kqz4xXyaeRKojzVI+5GocMQWmu++8JyxnBsnmoZa0qv1Oay+KhbIxHLd1/79jyPXwZ5sL269zKiSKrWz6NN2520tiKDaVdPCxlmXX8fAAAAACgAZVlfK6y7z+Qyhz3fXr7y7v7bfJq08e4ao282l4li/+Zzuflv2yW7fcMt3XKf/6msgyE1UfW9lF1GyXscZb5xo9wt4KnM+HloX6hp8XTtidMkorjJ6vskIeAt4KXvyzNe8jaZZw97J2SJf8R0+UdZFxRkmJlqcP1XpPSPz3U/G3SSif1R+SX1SpJ6KSCpfEeAt5WsvaZ4xIChcriIdy1ldc9JLF5KpfhYTInS7eZ8ostkP+NT8h6YYgsvh5euMFVE2bjNJauUuOXJI6l4KVOKlwUo5JNRzeZXvmPhpD0lYnukzs2HyCf6zzcQeHhd4qlRd2+3fUObOHe8NcW9+cAAAAAgBOmZN1UlWdvcz63NG8rn/AH3dapieXsDvIK6StnrcCSApXn7ZF395Fui63wT9mOVNZQtTRvK+slgyt35GA7WDnRmyRZzjzf7blZnNRYEsdJbpyKXgpXsceRqR74tpJXpyfk+6+D8h8cfdxHg9rKawiaT7ZdMSr5MdvYKgX3k4/HsLvhWK5cK/9hsbI2J1nF6Wihir93DVv95dVtjscvgWw/pbsFyZaW/Tyu566fvM5p5fI8T+X7aq8bDgAAAABoGLaiwyr9dp5sBY4BPBaLvNr3l6WF54E+RtaLouTVzCWA9QmQz1WPynfYe+7L4LmyNDvPuZrFv7mslwyq3M7BVpjnVmfX4m8YIOSBxa+prJ2HOYNka9d/yuJ6TCcB4e4pY4tKEzqp4LVIlWxKUvmhqiNHq/BqLq/LFso/JtYl4K2hRm0deZ9ZffmBxqA0c6eHEYch8qq9YkWdWZrUUpbjogXyi+znPjGVgy0nSaVfjldxwjAVvWLRkSmO5fOjGylZd1miuC5OBtvnc91H7wYOkPeNHiZf3Lpc5W7r9ZN3ZC3/8+8ZK6vrFzbF6SqvrN8PAAAAADgJStMWqSxjtSpGJ1r8z5S1w8DK0bg1sF58reR7hrG5dhaLvM7/
myxnd3Brtv22s7J8gyRL4LmytvPwb1IXXuf3ke/A2ZVB8pA3ZG37d+NqJ1TtZ+50lhVXGawd7TJrjHFvJ0Zukko+jVThSy10ZEpPFa16VqUZqTWEvf6ydJws3/CK59XUiz0KHn8eD2fKUVXl5MvydC78pWYVj9Nl8zD42OJVS5jkFCI5Jm105Vrv06dvjIdSLnkqW9tTBbMiVfzFXJVuXa6yOk1iV1/xshUY24ATYaGKN6W6tXh1me5hFP1ClWa41sv1l9dlCzys5xAULb++A9zr7e5e2Hh+BwEAAADAaaz0yxmy5f7kfG4JaFFrvVuvNn3kdW7lRGoqL5OK82UrOlztotIi5+qWZufLu9Owyu0llf38tcoPVU7+YjnjHHm1j3Jbx8jrvN72461QeFDlf/yf6yonnKcY7rTkqUaqmobW4dbfxmyLyjZMVvHicHvY+8YYlWZ6KEmg5rJ2me585ulcWFqEy2ps9KifvFp4GLVa6Ol1T5KWHo7PzXLZjhjbJMtZo41N1Rgtr0Bjm2u9z1h5nVu1FIMtY6KKkk9OTVCPn/mZPY97yQ7AyLZ2uspcLz/fMHlfM8Clwa7s4zj3Ub5N+snn3p3yvbKfy3UaIuuVC+Q/eqGsrj/mR5JU+vlClwYAAAAAwMlS/ttOlaYtkkrqPtLMGnqDLE1bO5+X/75LBS+2VsH0ZtUuZTuXSzbHPdDe/vK6KLJyh5Jsv2eoPHOdcxSxrD7yvjRG1g7Xua1XweLXVNZOQyXvykyn/NcfVP7Ld27rnWiEuw4lf2QZm6TAnrL2NDY2Ij2ny3fYEvflmhoCyF8WqnhRuEr2eRiWelaYM8j2eC6ah7rf4lydoGhZPGSn5TmNJ1ixnNXCQ2hZKB2qeJyp8oNVw061DPM0ILeq8DBZfI2NrvU+PU3Ylqfyn07eOTJOWCXZP/NqR0YCx81qFW12H71r6eh59G7RpwtdyyFJ/qHy7r9WAdMK1GRKgZpM2y2//qPl1czlB6w8W2VfjVHJCRkFX1eh8n7Gpia1LjtV5VeJC6+exvWrW2reDwAAAACcaCVfPq+yvV9Vhq+1sF54lWR1pDK2cpXv/co+OrcGZXu/lkoqb0/3OqdLleC2dMu7suXtdT63NDtfvoPmyPuy+9zWs7bvL7/RX8qrdS9nm60gV6Xb3ndb72Qg3K3wVYqHUrOtZO0z00MQ6EmIfAbPkXeVQKIBXThU3t2j3ZfLomsJIDNVkpNpbJS8VPk+PZ6LUFmvqyE4dvC+bqCHiypT5d97eM2TYrR8u3gqrpslm8skiaUZaVXr2/qHy3pdbaUZQuTTPbzqNVOXep9VT5wbS1NPofTxYZywSpLk1UNWDyMo3Y2WJcDYBhyl5MkqzXF57t1D3jfEujQ47Bqjwn/PVZU5G738JV//qj9DpZkqWxWuoq8by+8fAAAAAECF0tRXZcv3MMDQwLtbtCwtQ53PbYV/qnzPV27reFK2bbHKXYPbM1ra6/u6KN+3QSVfvWifXK1ivRbt5Dt4rpo8eUQBkw+pyZN/ye+2z+UV3F2yOJKZ8lKV7fxYZds/qNzZSWL8p/DpK/dZlWUZEwNJwePkf8u4WkK1HvIZlSafy8bJd/QqDyPOGkhmZtVAzj9c1lomGvIO9DD7X2lh5b6qOReWjjPld2X1+7Zctko+HV1np3fIWq2SylIqJ09QtHzHz5HVU8mE3FSVuh5jSrz7BE6So87nKvl2NLZXsg5MkE9wlWG5su2prd5nc3l18BBmVei4QH59PIXSx8nWOI/1fS0dZ9bwfkPkM2qm5/MJHJUkFW9ynzDQEjJOvhe5NFTYNV6Fz/ZU8dYtHmtjS5JK81S+81kVz26nom8JdgEAAACgMSrL+Fxl338gldWcmHhdFCmLf+Vt4raDmSr9Pt5tHU9sRYdVvveryrILFi95XRAui19Tt/VKv3tDpV8+J1uB4a5mnwD7uj5NKkNdSSopUOm381T8yUmab8vA8tfkqvmgad24U016
Vib5dukqndJJxYZWjzomKCBmqMcg1/ZrvEpXTVZJpmtQ0EPWv0+Uz1XR8nLN8/KTVPJWpPM2YJ9RB+UTYqhVkLOw1knXat9utPweXVA1XCvNVNmaYSraULV+q1efBPn1HyqLoZiwLXOiChbFVTZUdy7Ks1X+7WQVr1roMrq3h7wHzpRP735V9itlq+zjYBWlubYtkP8zow3fLOSp7PMWKqptdKvHbSXlrFbpLx5KKUj2SeOCw2Vt2aqarzMKVZ7SSYWfu4dAlsgUBUSEubVJjvO7doyKv3YJo4L6yafvHHl3D616zkq3qGR2T5fbwqPl9+iSqp+bCmXbOVFFS+dWntuK/XYL9TxRXZXrO1Z+T8yU1ZAvl6dZVPixe5uRZWCaAjwFyKWZKtv0rEpcPnNLSKx8Bk6W97keanB4up6ABmIJGSrrmZUXvO3PVJW5/Z4GAAAAANSX353/lbVdZc3a0s0Lq4Sb3t1HymfQXGd4ajv0i4oTblPZ7ooylVVZ/JrK784keZ3X29lWtnutit7+h73/7A7yi/lMXmdfYu+0lduD1ZXjnevXxLtbtHwGzamcBK04X8VrHlXpt68bV5VXmz7yufJhWdtHSb5VQhuptFDlWVtVkjRVZT+uMfbWyH/CD/I6p5P9SVmxSr56QSVrnzSuVoVP5NPyuepRyWovtlf+204Vzu7stg7hroH1xp3yq7IPF6WFciZc3h5uA5ak0nSV/ruTinfZn9Ye0npWl+0sV66Vf3/XCYVcHMlU+eFs2bKypJah8moRIkuTqqNKpUyVLWmnIsfxVqjxXJQXShW1L6s7D5Jsu8aoYImxlqyngLae4W59/LpQRXPGqMzYrhD7iGzjZ1ChTufAc3BsvWm3/LpVMwradb++nj4vV3kqW9VCRRsqnh97uCsNkF/sKveJqFxVHJeXv2qbaZBwFwAAAAAAmJX1koGyBFTemm8ryFXZ/1a5rdNYeIyjTmdlHw9Uya5sY3Mlb0dtR0/1HeWYvGftQGew29BsX0eqOC3d2GzXJERe54bJ2n2orOeFVhPsFqo8bXyVYFcV5yKzmtGwFTUuqzsPkmz74lRcJdhtZH6NV3G8p2BX9vrEi2JUmlPNvd+1noNC2XZMVJEh2JWkso/G122/tfKX5Qxj27FaraJVhgmrXFUclzHYLa3mfQAAAAAAAJhQ2f9WqXTrYufSWINdEe56kqmSJeEqTkl1luSos8ItKk048ZP3lH3cSUVfJ8lWXShXnfJslX89TIUfrzb2OGSqZFFPFW/1MNlWjQpl2zFehfMnVhOaNgKlmSr/drwK5sSo1EOt2UqrVTxrmEoyawj8PXGc24J/z63m3K1W8fuTVXaojsFoebbKv56r8ipD0P1laV37RHd1tmuMij6Orz7gNcpdruIdJ/Z6BwAAAAAAgB3hrkeZKv08XAXzx6g0I7320PRIuso2jLFP8vP9yQm6ytZEqmB2jErqcryleSrPmKvi+cEqXFNdsFshU6UJnVS4aKJKf8muLEnhSXmebL/Eq2RRpxpCzZOkvFAqzFb53tUq/XKMCqe1U+GKuh7japUsClbhx3NVllPNSOYKR3Nuc+NU9FK4fWKoKqFtJVvOapUkhKtwTZzK/zD2SpbWQ6sMpq0P2/cxKpwdY68vXd3nXZyt8q3jVfjKMJX+mlXH8wgAAAAAAIDj6dSquduAvDpGy6tNP3k1C3C0FMj2S5LK98Wr7BfDyo2AJWSorO0HuByvpEMpKt+XotJdVSdaq7sQWbsMkNdF4ZWTp5VmqvynJJXtcJ/t/tTVQ9buA2S9yFEIW5IKd6r8p9X1OrdeHcfJeonLec1ZpbKfTvb11UPWPsNkPc9RH/i0+6wBAAAAAAAaL8JdAAAAAAAAADAhyjIAAAAAAAAAgAkR7gIAAAAAAACACRHuAgAAAAAAAIAJEe4CAAAAAAAAgAkR7gIAAAAAAACACRHuAgAAAAAAAIAJEe4CAAAAAAAAgAkR7gIAAAAAAACA
CRHuAgAAAAAAAIAJEe4CAAAAAAAAgAkR7gIAAAAAAACACRHuAgAAAAAAAIAJEe4CAAAAAAAAgAkR7gIAAAAAAACACRHuAgAAAAAAAIAJEe4CAAAAAAAAgAkR7gIAAAAAAACACRHuAgAAAAAAAIAJEe4CAAAAAAAAgAkR7gIAAAAAAACACRHuAgAAAAAAAIAJEe4CAAAAAAAAgAkR7gIAAAAAAACACRHuAgAAAAAAAIAJEe4CAAAAAAAAgAkR7gIAAAAAAACACRHuAgAAAAAAAIAJEe4CAAAAAAAAgAkR7gIAAAAAAACACRHuAgAAAAAAAIAJEe4CAAAAAAAAgAkR7gIAAAAAAACACRHuAgAAAAAAAIAJEe4CAAAAAAAAgAkR7gIAAAAAAACACRHuAgAAAAAAAIAJEe4CAAAAAAAAgAkR7gIAAAAAAACACRHuAgAAAAAAAIAJEe4CAAAAAAAAgAkR7gIAAAAAAACACRHuAgAAAAAAAIAJEe4CAAAAAAAAgAkR7gIAAAAAAACACVn+miybsREAAAAAAADAqavJM0SCpwJG7gIAAAAAAACACTFyFwAAAAAAAABMiJG7AAAAAAAAAGBChLsAAAAAAAAAYEKUZQCARojC9jCrI1MsxiYAAAAAQANh5C4AAAAAAAAAmBDhLgAAAAAAAACYEOEuAAAAAAAAAJiQZd26dRR2BAAAAAAAAACTsSQnJxPuAgAAAAAAAIDJWNavX0+4CwAAAAAAAAAmY/nqq68IdwEAAAAAAADAZCzffPMN4S4AAAAAAAAAmIxl48aNhLsAAAAAAAAAYDKWLVu2EO4CAAAAAAAAgMlYdu7cSbgLAAAAAAAAACZj+emnnwh3AQAAAAAAAMBkLL/88gvhLgAAAAAAAACYjOX3338n3AUAAAAAAAAAk7EcOnSIcBcAAAAAAAAATMZSUFBAuAsAAAAAAAAAJmMpLi4m3AUAAAAAAAAAk7GUlpYS7gIAAAAAAACAyVjKy8sJdwEAAAAAAADAZLyMDQAAAAAAAACAxs9is9kYuQsAAAAAAAAAJsPIXQAAAAAAAAAwIcJdAAAAAAAAADAhwl0AAAAAAAAAMCHCXQAAAAAAAAAwIcJdAAAAAAAAADAhwl0AAAAAAAAAMCHCXQAAAAAAAAAwIcJdAAAAAAAAADAhwl0AAAAAAAAAMCHCXQAAAAAAAAAwIcJdAAAAAAAAADAhwl0AAAAAAAAAMCHCXQAAAAAAAAAwIcJdAAAAAAAAADAhwl0AAAAAAAAAMCHCXQAAAAAAAAAwIcJdAAAAAAAAADAhwl0AAAAAAAAAMCHCXQAAAAAAAAAwIcJdAAAAAAAAADAhwl0AAAAAAAAAMCHCXQAAAAAAAAAwIcJdAAAAAAAAADAhwl0AAAAAAAAAMCHCXQAAAAAAAAAwIcJdAAAAAAAAADAhwl0AAAAAAAAAMCHCXQAAAAAAAAAwIcJdAAAAAAAAADAhwl0AAAAAAAAAMCHCXQAAAAAAAAAwIcJdAAAAAAAAADAhwl0AAAAAAAAAMCGLzWazGRsbg99++03Dhg3T3r17jV2NwvDhw/Xyyy8bmwEAAAAAAADghGi04W5WVpb69Omjn3/+2djVKNx222169913jc3189MCDb5mhn5wbes6Ses/HaM2rm1Ho+xHLRgapRnbK5ua3bhIX8ZFqJnrensXKKrvDP3o2lZnfmrZqpnUsrOuH3iLrh8Soc6t/Iwrmd/BH/ThO3Fa/FGafsg+5Gz2a9lSzVrdpkWfjlNntw0AAAAAAACAhkNZhsbkojF65TFDPLh9hu6bf2yRq3RIyY8Mdwt21WykFr1oCHbrrUg52TnK2Z6sBc/fq8HhoWp3zWNas7fIuKJpHdoSp+GXD9Zjs5Pdgl1JKsrJUU52oVsbAAAAAAAA0NAIdxuZ9mMX6fkw97Yfnn9QC35yb6uLQ4lP6t6PXYPIzpqUME09rS5NDeWnD3Vv37/rsbXuQagpHVqjJ6PnKq3U
2OGi28Vqb2wDAAAAAAAAGlCjLctwWtfcPfSZRnV/QMmubeeP04p1sepc12DWQ4mHzo8lasXYaiLIKmUZ2ij8n2E6322l6vymH5J/0I/ZOao6VrezJn2xQmMuMrabx775Ubr6ecPo6WadFTGgs86RpCP7tf+SqXr/3mrOLQAAAAAAANAAGm24e7o7lPiALr/nM7ewtNnwRdr4YoRqrWZb9qPmDoxSXIZLW+Sr2vrWkOrLMVQJd4do0e5XFeG2Ui2K9mnNlFt177J97u0dJmn9f+pRN/gkW/NAO937WeXzNmOXKfGxnrV/DgAAAAAAAEADoixDI9Us6mnNu9E9ij207AE9WmuZA3udXbdgt9kQLYqrIdg9XvzaqP+L65X4sGEEa0acFqS6N5nHPu3b5fq8vUaOINgFAAAAAADAyUe422g1U8RTr2qIWyJ7SJ/FPqnkGvLdQyse0Ci3OrvNNHLBq4po8GS3Uvux0zTSLf0s0oeJaa4NJtZJ7S80tgEAAAAAAAAnHuFuY9YsQq8uGOk+4vbQZxoV+5k85rs/LVDMv9wq9arzY8s0rbdbU8OzhunmUS3dmopS02Qo1gAAAAAAAACgHgh3G7ve07TkXkO12rUP6L4PctzbytL01DD3CdTUdZJeGX1yJvnqfGmYe8PefTIcMQAAAAAAAIB6INw1gc4PLVRsB/e2lEmjtGBvxbNDSn5klBa7DeeN0KuLx6i91bXtBDK+blGRh9HG+7Tg2nZq186xxDpGHRft05rpt+rq0Iq+UF1+/SjFLftBOa4zzBmVFenH5AWacddgXX55aOV+27VT6OWXa/BdcVqc+qOKyowbGqx9wGXbqzXDtX6xPtMol/22a9dOUfMZkwwAAAAAAIATz2Kz2WzGxsYgLy9Pzz//vHJzc41djcLll1+u0aNHG5sbzk8LNPiaqiNz1386RmcmPqDL7/lMlblnMw1560u9GnkUhXb3LlBU3xn60dkwRIt2v6oIt5Xqbt+iwbp6usvRdpik9f8ZI/cxyPu04FqX8PTGRdr9lPTYNaP0YTXDfPu/lq55g43TmRVp3+qnNPqBD/VjqaHLE7/OujnudU0b0MbzxGhrH1C7uz4ztlar/WPrlTjWMLoaAAAAAAAAaGCNNtzNyspSnz599PPPPxu7GoXbbrtN7777rrG5Qf04f7CinneLdxX+8CSdPWuGPnMZ0drs9mXa+lRP19Vqd1zD3SJ9dleoHljr0nTjIu2OM+7NEO4OidWkzDjN2G5YrYLfSC3ZMU1hrqOCy37Uh3cO12NfVx0XXJuWw+fpi+f6q5lxlDHhLgAAAAAAAEyAsgwm0n7sIj1vKGWb8pJ7sKuuk7Rs8lEGu8fb9jmKcw12JfWPMBy4J59VBrvNeozR85+u18aNG5W45FVNurG9ej44xj3Y1SElP+Ih2G3ZUzc/9qqW/GejNm7cqI3/WaJXHxuizoZhujnL7tXfH0muWi7iyqft223cqI0bl2ncRa6d/fWqs8++LLuVYBcAAAAAAAAnHuGuqbTUzfNqGk3bWZNePol1diUVZSzWrSPnyq0Krd9I3TbAYwEEj5rduEhfJkzSzV3bqGXLlmofNkRj4hK1zDA6dt+iGI362D2a7TlhmbamLNPzY4corENLtWzZUi07hGnI2Fe1YsdWLbu3s9v6hz4eVXVyOr9m9u0cy5lu59NPzVz6WrZsqWZNXPsBAAAAAACAE6PRlmWg5m71DlWpsWsX8dpWLRp8FHV2XdWjLEPRwRzt2/UffbposeaurdxDhfAZG/X+P1sam6uWZVA1pRc8OfKZ7u3ygNa4NHV+LFErxrZ3afGsSnkLvyFatPlVRXgMaY3HWPfzAgAAAAAAADSkRhvuogZlyXosdJQ+dJs8rJlGfrhV03q7th2FKuHu8dHsxkX6Mi5CniNnY3Aq+d2+TOl1qBec88GtunxSSmWDxwnbqlH2o+YOjFKcy+vWPYAm3AUAAAAAAEDjQFkG0zmk5EceMAS7
9vbFYx5QcpUCsidP+1GL9OWL1QW7nvUPqz3YlYqUutYl2JXUf9zIugW7kmRtrzHj3OPZlLWpVUZCAwAAAAAAAI0Z4a7JHEp8Uvca6sw6HfpMD0zzMEHYCeWnzjdO0qJ16UqcHKFmtZVXcNNeF7tNXladVKW4TdgWoesi617TV5L8+lyrcNeGtSlKdX0OAAAAAAAANHKEu2ayd4FiDLV2mzVzHxd76ON79WTi8Yh32yj8nzfr5josYx57VfPmLFHixq1K352uFXFjFHHh0YWtdueoZV2G+eb8pv2uz1t21kUe6+XWoGV7XeR2iPv1m2FeNQAAAAAAAKAxo+auWZT9qAVDozRju0tbs5Fati5c74Xfq8/cagqE6/mU93VzK9e2WlSpuXsiasseYz1b47EeTb1dJ+Nrt9ekdYkac6H7WlXXq+MxAgAAAAAAAA2Mkbsm8cPLo92DXTXTza9OUs8W/fX0a0PkPk42RY/dvUD73NoAAAAAAAAAnEoabbj722+/6aqrrtIFF1zQKJeHHnrIeMgN5tDaBxQzzz2qbXb7Ij0fYY90m0U9rXk3GuoZbJ+h++ZXjsMFAAAAAAAAcGpptOFuWVmZfv75Z+3bt69RLn/88YfxkBvGoWQ9GfuZ+yRpHWK1bHJPl4ZminjmVd1syHd/eP4+zf3Jve2U0KSZznF9fvDPY5hE7if96Cy1IEnn65yWrs8BAAAAAACAxq3RhruQpBx9eO8ofeaWXHbWpLnj1N7q2iapSYSeXzBS7vnuj4q7M04/lLk1mp9xMrScH/TTEZfndWGclM2vjc4/2knZAAAAAAAAgJOIcLcR+3H+KD2W6t4WPmORxlzk3ubUe5qWPdzevW3/XMU8niy3+dZMr7N6Xen6PFmfrz26d5iz9jOluDZc2UudXZ8DAAAAAAAAjVyjDXetVqsuuOACtWnTplEuZ511lvGQj6+f5uq+539wa/IbMk+v/7Pm2gHtx76uSV3d2w4te0CPrj36wgWNl5/CIsPdWtbMXVz3CeTKftB7r7tFuwqPDDNMSgcAAAAAAAA0bhabzWYzNuIkO5SsB/5uKMfQbIgWffmqIgx1dT36aYEGXzNDbtFwbdvvXaCovjNUOQXbEC3a/aoi3FY63vZpwbVXa4az9u1RvOaRz3Rvlwe0xqWp82OJWjHWMHLZgx/nD1aUW3DeX6/umKchHssy1OMYAQAAAAAAgAbUaEfunr4OKXnaA4Y6u81086svVB/MGl00RotmuI9s1aHPNMo4MZuZNRmiSZPdCyn88HyUhr+SpkPV1RguO6S0l4zBrtR58qRqgl0AAAAAAACg8SLcbWRyPrhPoz52j2DbP7xMz0ccXdGAlv98Xa9GGhrXPqD7PsgxNJpXm1FLtOhG98Q7bfZwdQ8frsfmf6bUjBzl5OQoJyNVn81/TMPDu2v4PPdgt9mNi7RkVBu3NgAAAAAAAMAMCHcbk58WaNQk91qw6jpJr9eh1EBVzTQkbpGGGEb7pkwapQV73dvMq5kiXlym5680vMmcNH34/AOKufZyXX755br82hg98PyHSjPk2i0HvKovXoxQXQdEAwAAAAAAAI0J4W5jUfaj5o4z1MlVZ02aNUbtrW6NddcsQk/HDTGElz9oxv0L9GN1pQvMxtpeN7+3Uevn3Kz23sbOavh11s1z1uvLOUPU8ljPLQAAAAAAAHCSEe42CkVKfny44pyTdtlFvLZEYy50bztazSJf0KvDDWNTt8/QffMrp04zPz+1GfC8EtPTlbhoksZEdlbLlu5lLPxatlf4kHF6/tONSk9foecHtNHRFboAAAAAAAAAGheLzWazGRsBAAAAAAAAAI0bI3cBAAAAAAAAwIQIdwEAAAAAAADAhAh3AQAAAAAAAMCECHcBAAAAAAAAwIQIdwEAAAAAAADAhAh3AQAAAAAAAMCECHcBAAAAAAAAwIQIdwEAAAAAAADAhAh3AQAAAAAAAMCECHcBAAAAAAAAwIQIdwEAAAAA
AADAhAh3AQAAAAAAAMCECHcBAAAAAAAAwIQIdwEAAAAAAADAhAh3AQAAAAAAAMCECHcBAAAAAAAAwIQIdwEAAAAAAADAhAh3AQAAAAAAAMCECHcBAAAAAAAAwIQIdwEAAAAAAADAhAh3AQAAAAAAAMCECHcBAAAAAAAAwIQIdwEAAAAAAADAhAh3AQAAAAAAAMCECHcBAAAAAAAAwIQIdwEAAAAAAADAhAh3AQAAAAAAAMCECHcBAAAAAAAAwIQIdwEAAAAAAADAhAh3AQAAAAAAAMCECHcBAAAAAAAAwIQIdwEAAAAAAADAhAh3AQAAAAAAAMCECHcBAAAAAAAAwIQIdwEAAAAAAADAhAh3AQAAAAAAAMCECHcBAAAAAAAAwIQIdwEAAAAAAADAhAh3AQAAAAAAAMCECHcBAAAAAAAAwIQIdwEAAAAAAADAhAh3AQAAAAAAAMCECHcBAAAAAAAAwIQIdwEAAAAAAADAhAh3AQAAAAAAAMCECHcBAAAAAAAAwIQIdwEAAAAAAADAhAh3AQAAAAAAAMCECHcBAAAAAAAAwIQIdwEAAAAAAADAhAh3AQAAAAAAAMCECHcBAAAAAAAAwIQIdwEAAAAAAADAhCy5ubk2YyMAAAAAAAAAoHFj5C4AAAAAAAAAmJAl48c9jNwFAAAAAAAAAJNh5C4AAAAAAAAAmJDFZrMxchcAAAAAAAAATIaRuwAAAAAAAABgQozcBQAAAADAoaioyNgEHDM/Pz+351xfONmM1yTMj5G7AAAAAAAAAGBChLsAAAAAAAAAYEKEuwAAAAAAAABgQoS7AAAAAAAAAGBCDTKhGgXCAQDAyWScKIK/TY6d8VyK83nMPJ1LAI0Pv+NwPBl/93N94WQzXpMwP0buAgAAAAAAAIAJEe4CAAAAAAAAgAkR7gIAAAAAAACACRHuAgAAAAAAAIAJEe4CAAAAAAAAgAkR7gIAAAAAAACACRHuAgAAAAAAAIAJEe4CAAAAAAAAgAkR7gIAAAAAAACACRHuAgAAAAAAAIAJEe4CAAAAAAAAgAkR7gIAAAAAAACACRHuAgAAAAAAAIAJEe4CAAAAAAAAgAkR7gIAAAAAAACACRHuAgAAAAAAAIAJEe4CAAAAAAAAgAkR7gIAAAAAAACACRHuAgAAAAAAAIAJEe4CAAAAAAAAgAkR7gIAAAAAAACACRHuAgAAAAAAAIAJEe4CAAAAAAAAgAkR7gIAAAAAAACACRHuAgAAAAAAAIAJEe4CAAAAAAAAgAkR7gIAAAA4eVbEyGKx1G3xaaHgNj018Mm5Svqx0LgnAACA0w7hLgAAAABzKM1T9v4tWv3MeEVeHKBODyUpz7gOAADAacRis9lsxsb6KioqMjYBAACcMH5+fm7P+dvk2BnPpTifx8zTuYRj5O6QePvjtv00+poQ4xqVstO06ostynYZtNvjxZ1KezjUdS2gXo7pd9z+pYq+ebb2GNtrEfViiqZeaWw9zX09VeGPJErtJijh/REKNvbXRfqbGjT6XeVKUueHtPKt4QoyruMmVVPDY5Vo/EwqjkVRikuZqjDDVnVh/N1/TNeXB8W5GVqT8KYSVm5XRk6+o9VHQRd2U/+YCYqJ6qAgX8NG9ZS1JFrD5u4xfDZZWnrrMM3eLbUdl6D4mGP6xE4DxUp8JEJTv5YkHw2bvV6xvYzruGuo8228JmF+jNwFAAAA0DhcFasF8xdUv3yWpqzDB5XyeA/nJlseGa+F+932AuA0t2llvD3YlaQfluqj3e79placpRVPD1LEoDv03NupLsGuJJUod+9mLZ1xhwZFjtWbOw679OGkyvlMy76ueFKihGWJKnZfAzhmhLsAAAAAzMO7ucJmrNacPhUNSYr/PNt9HeAkinhypVaurNsSe5lxa9RbWapWflwiqZcirvaRlKX4hE3Gtczp8CbF3TpMz/3HHl23HfKQZr2/RslfpSglJUUp/12j+Okj1CVQUtkOvTt2hGZsIEJsDLIS
l2uHJJ+rI9RLktYv02c5xrWAY9N4yjJ8v1hTPsowttau43A9E3Op48k2LZ6yTBmSWvz9fj10TUvDyg1v25IpWrZLUlCE7n8wUkdzBM5tj0GHm57RyG7GVgAATk/G282O6W8TSB7OpTifx8zTuYShLMPIVbK9N8C4hkfZb0UqeGyS/clRbAfU5ph+x7mUZaDUQj3VsyxD8ReTFPFUstRuguIfPKDbJySoRFGK+2qqwqzGtSuYoSxDrlaMH6Tn0iSpg26fP1t3d2lqXMnucKqmDotVYr6kwGF68/NYda32vdddQ5UJOPXt0ZvDovVulhT1YoK6vDtML/9Q+/lqqPNtvCZhfozcReNUnq+ML+brkx3GDgAAAEBqdXYrY5MkKfOlTrJYLLJYYrRahUp/O0Y9Ayz2toB2ihy7UFtc72KWpPx0rX4pRj3btHBsa5HFEqDgvw3U5Le3KK/UsL4nhZla/dIYRYYa9hEaqYm17aM0T1venqyBfwtWC5+KbS1qERqpMS+tVqZLjWGPft+ihU8OrHr8ddq+UJmfx2lMv04KrjhPFossLYLVc8hkLdxahynrnO89WAHHcP4Kf1ytuLGR6hQcUPn6Pi1q3/7HOHWqWL9rnDKN/TgNFSv5i2RJUtCVvdS2+9Ua4iNJiVqRZO4RrMVfz9LMNEnyUdRzNQS7ktQ0TFNfGqZAScpP0OzlziIVOBnS12hFliSFKbxnsPpFdZEk7Vn+2VHX6QY8aZThbuDFkbrhxhvqtoRfYNz8FNBavYzvs5blilPpNBzarMUvvKDFX+7T4XJjJwAAACBt+c4xaldSaKjnSdi2vBSuTqPitaUi3CzMVNIXmfIPrFwne+kwBTftpIGPxGvLftcgs1DZm1br2VE91aJlpObuqD4hzV46TMEB7TTwkYVK2mXYx64kxY3qqRZthineU23g/fEa1qaFeo56Vqs3ZbsFmXm7krTwkYFq1zRckzd4Dlnz1o1Xu5Y9NeaZ1VWPv7btS9O1MCpY7QZM1MJ16W4T1SkvW1tWPKsxPVoo+OZ4ZVcTsOZtmKzwphXvPVuVu3A5f9W9d0npb0Uq+OKBmvhWktJdD6A0r07bnzpytWJ8uMLDwxUeHqtED6VSizfMUFS4fZ0Jn1aEdVlaequ9berXknKSNXv8MF3tWC88KlqT3kpVVk25ZnGWUpdM1R3XRzle375EjZigGR9t1+Ey4wb2EYXh4eEKfzpVKjus7R9N1R2DrnbbdnZSVvU1Ratsc7WGjZ+t5P3VblE3zrqmwRoc1UGy9lbU9T6SpOQPP6usw2s6uVrzQaJKJCk4WndcXUOwW+HSaN3Z0UdBHcPUtbmH0cJlh7X9oxmaMMLlc796kO54eqlS6/s5uClW1oalmjpqkKKuqry+wqOiNeH5Zdru4VeTvp5qX+fWpcqS41qJqtj2ag0aNUMr/s/DD4lD8f5kzR4frUFXu1zP19+hqUvq+LPgci2HR9Vhu1o4a0Bf3V8RTaSgyMH20gxZ8UrYbFwbOHqNMtz1Ce6sXj171W0JaeGy5aUa+cwzeuaZZ05KSYbjJ1AhxvdZy9KuuXEfJlaQp5wjxkYAAADA4feFmvxcRZ1df/W7KtSwgiTFa/IjW6RWA/TE6p3KytqttEWxip0xXhVrp7/UU8HRy2XfUyuFPThHa7dnKSsrSzuT5ij2Ssfo4Lwkje8arjgPJdTyVsQo1LkPKfSuOVr13W5lZWVp93cLKveRvVwxXccr1TUk/XGuwkNitNyxcasrYzUnaaeysrKUtX2t5jwYplaSVJqqZ8P6VX39vHjF9JtrH7HqHarRr69SWqb9+Hd/l6CZIx3vtDRVz/Ydp9WGEctbnhyoMV/Yk5VW1z2hJRWvnbVTa1+PVVjFoX8Uo36vVB0XW/j5GIWGPet8T6EjZyrB+d5dXj97uWK6xmi1McTZNFkDxyYpT7J/Tu+t1c4sD+c/e7lirjvVR+YGafBTj9sDH6Xq2ZnJcouuDqdqxpMrlS9J
V07VjOuDXHvtdi/V2KGTtDQtyx4CSlL+HiW/HathQ2co1VMWtnepxkYOU+zcRMPEXFL+3s1a+crd6n/XUu3xEPBKksoytPSe/rr7lURl5DpfVfl7N2vp5GEa8ZaH0otle7T0LuM2JcpKW6pJN4/V0npMflZR11TthmtIO3tb1xui7bezm3litSOblJxmfxh0TV+1NfZ7FKwRi9Zr5aI4TbjGcOt+TqImXd9fd7+yUpv3unzuJbnK+M9sxd48VJO+OA5ReNkeLb0rQsMemq3EXbnKd72O8vdo82cv6+4Bd2jpXpd2N78r+dkR9mvFeZglyt21Us/d3l9jP6g67jX3i0kadPMkLU3bI5dLUvk5GUqcG6thg6ZW87OwQhMGOX4W3Das2G6SEo+lRq6zBrQUNTBKvpIU1F+DrxYTq+G4aZThLgAAAAC4K1RedrqS5o1Rp+AxWl3R3HeOJldb33SAlqSv0vTrQtWqVYh63DlTM0c4AsP9CzX+kS2O9Xpo+vZMpbw8Tv26tFKrVq0U2necZn6Vpaz4obKPo9iiiWGGcLY0SZNvjreHk+qhJ1IPauf8cRrQO0StWrVSSO/RmvlVulaNdIzEyJur2HkVMXC2Fo6t3F+PGTuV+dVMjesbqlatWqlVl34a93KKMrdPVw/J/vrXTdYWl9fPfHum4zyE6InUnVpw7wD1aGs//pDeQxX73k7tnOEIWAvjFbfUdeK51Zr7nCMuHbxE6aunK7ritVuFqt+9M5WSnqBof/sq6U/GKdVla5WmamL0Qkeo3VzRyw9q53uxGup87/bXP7g82n7+8uI18LaKc2W3+vVnHYGt43Ma2U+hrVzPf7oSRjgOYMdkxTlnmndoH6udNptsNpts22Plefy2ibQcrBnT7FVdS76YodnOibAOK/WVpxz1U6MUNyVKnsZtJr4xWzvKgjVoerx9gq2vkhU/fZA92Mxdqdh/LVWW6wZl2xV312ztKJN0/iDNeH+NklPsE3Mlr4nXjIGOQHDXbM2vrqTBF29q9g+B6jVullb+17Htylm6vaO9O+vtmVrmFogVK/X5sZq9S5KCFDEpXmuSU5SSkqyVsycoLChDs99IdN3gKOzRZ8vtYV+X6/tX1uptN0TD28ncE6vtzdT/HA97d+lg6DxKR1I1I2aqknMlWbtoxMsVn0GKkpfGaURnScpV8lP/rPdkbNtfG6vZP0iS/bqseJ2U5DWV16bsn7nHV9q9VLNX5SrwbxM0a2Wy8/qa8Df77Rc7Zo11P8YjiZr5VLLyJXW4rfKaTPkqWStn364uVkn5iXrsDcN1cCRVM+56TpvzJQX20u0vO36GUpK15v0ZGnS+pPxkTY2J0/bqvuioRnHSCiVKks8wDXIWbfZV1MAo+0MmVsNxQLgLAAAAoHFYPNClZqxxCVCL4E6KvG+h0isCzubRWrV8tH10qwf+EyYrupo73FJfHK+Kwg4D4pP0RBdHiGjQakSCVk+oDGdnfuRSeOCjOM11PA2ZskTT+3h6seYa8PpcDfX2V6vze6iFMu2lC76ervHrHKsMXqKkx0Pl6Qj8uzyhhBcdAe2eZzXz88q+zPSKcLqVgqs5CaETJmuA/NXq/FY6+LtLtLc/UzsdD/3Pb+UIsA2aD9XEB1tJzVupVdtCZblkw9lvT9ZcR1Ib8niSltzocQ9qfuNCLRnpeGcrpmvhjxU92cpMdzz0D1Erl1IZlZpr6EOxauXdXK3OD1Gh6wE0YomPuNx6XtPytFtcLklqes1Ux2Re+Vr55GxtL5MOfx2np/6TLylQUU/GKsxTsivZ+198R5P6tZWvVZLVV237TdLS2Y7aq7veULzLLeDFSUuVkC9JYZq6cJIi2jW1jyqU5Nu0rSKeeEsPdbY///6nqiMkK/R69APNjumtoCb2575BvXX3a487Jh/bodStLuFbzmd6e1W+o27sUs0Y1FZNfSXJV0G9Rihu6bFNWiZJ2pag+CxJ6qXBka4jm4PV/3p7jdOSj1cq
9SjDuUbhjyxHSYm2uqRuw3artWfJTK3Ml6QOmvD+fE3oU/EZSL4XhmnCG/Ga0FH2a/Clt4+9JuyRRC1dZh9uGzbNfl1WvI58m6ptv0l660H756LtmdW/TscJmv/yCPUOsm/sG9RbI15+Vw/1lP0YF7qU2/h+k5IlKeh2Tbqn8pqU1VdBve7WnCfDJJ8gNd3xg1zHlLudk7dm6+4+jp8h+appuwhNWhinqMBjqV+cqzWf2mtA+1wfpd6uk9qFDdIwH0naoeWJbl+7AEftFAt3t2nxlCmaMmWKXv7C/auPbUvs7VOWbJMk5e/fqE/eeFkzpjnap83Qy298oo37jbMrGJTna9+GTzQvboamOV5r2oyXNW/5ZuVUU4fqhCrP0LJnHO9p0UYVGPtd5SVrnuM9zFt30Ngr5e/TxuXz9PKMafb9TZmiaTNe0Kwla7Wt2jebo7WvuH4Gpcr5fq0Wz5mhaU85juuZGdXsw/H5zUlWxdFkfOTYZspi2T85h9IcbftiqdvnMOWpaZox522t3LBP+dTqBQAAOKW1unGmUv5viQZ4zhQlSf3C7GNeq9qiVc6QNlrjbqphJ5LCxo5zjgpdvsI5ZlirnY9DNe42T6UhHAKjlVBSoKx9aVo1IUz+krasTnDWp42+yzG6tRohN452lpKI/7jy9Tv9rZ/jUarG3zheq3fkudS8dQiM1ipbgbL2ZSntcZfzcX5P9XRkroXzhmnYvC3K9vBPoR4zsmQ7mKWs9AUa6gyQC5W0oiIaD9W4UdWdZ0ny14Cbhjoep2vJioriCq3U828VBzBXw26eqy37qxy91Hu6skoOKmvfTi24qZoE+5TSVFFTHAFnfoKmvvym4p5IVL6kwOFxmnpltcmudPVETfLQ79vrTo3tLEkl+mz9dmf7nv0HFBTkI13ZXxFVN5MUpPPa2B/l/vaHsdMhTIOv8VAiomlv9XaURchy2TZr7Up72YSgaEV7qhvbNEp33+Zhf3WwKfEzezmKqwerv2EXzhqnp8DEavWTpdT19iDR58bxGnGhsV+Sta1GPDRCQZKUtUJrKr6EOVp7s3SgZZB8FKb+V3j4rCUFBbe2P8jNkucrLFi3Tx6htq6hqCRZgzX8zmHykaQf1uubirw1wM/elrtea3ZUrb3ge02cUtav1Mp3b1fl+OcMrfnUfk6Cbpvk+Zw0DdPdt9tT9R2frnEfAV+TnCStSJOkYEXf0NW9z6UeNBOrob5OsXC3Lkp1IGmW4t5cqc2/HFRBRb5YWqCDv2zWyjfjNCvpgGEbhyMZWhb3guav2qwDeQWq3PSgDmz5RLOenaX1vxm2OdG8Oujyno6vvTPTlHbIuEKlnO/SZH+nrdWtl2vtYunA1/M146X5WrnlgA46T5JUWpCvnF3JWjbrWc1atbvm8FiHtW3pi5r1UbIyfi1QaUXgWlxg38ecF7X0+5r34NGRbVr84iwt+3Kn2+eg8lIV/LpbG1fN1wtxy5RB3V4AQKNRrMO5hz3fcoiGV3xYuYc5+6bQtp9G3zW6miVWM99LUELSTh0ssClreazCzjbuwF1wc09jYSWVZiqzYhBo337q6W3oN+oSXjmacHumo5RApjIrcjL/fgpvX7FCXWQrbYOzSq9CWmYrO7uGJTDYGe7qu3Rn7dlWN41zlk3Qprka2LWFAnxaqFO/GE2el6T0PA9hqVOYxj1Zsdc8Lb+vp4KbWhQQ3FMDH4rT8k3ZKjSOxXDaopQvKh6HKjjQwzG7Li1DnMe/ZUtlUhR232Rne97H49WzTYAsAcHqOWSi4j7yHDabQcSTK7VyZR2WB+1xYxVNozT1uSj5SMr6+F0llkgKvl1z/mUIhwx69entHHnrLki9r7IHUyVpO5zBVIc739HKleuV8qKjDqiKdTg3V1npqUpcMluxo6IU+x+X3XhyYTd1qBgdWQd7MhzjJS/r5hKuuevQPdzYVDuXuqYR/SKqngdnjVOzT6xWT2X/0/8cdYf7Xtbb2FupS2/ZP4Vc
7cg4xrMVerve+XSl1qfEKcpxjRQfzlVuVoZS/7NUsx+5Q1GP11KCwydc4Y4vCaro3kt9JUmbtbniNoQuURoSKEl7tHRsf1198wRNfWuZNu0+rOLqRmzn/k87HG+xppIXwV162YPj3Zu0vY5Zg7MGdHCU+nt4H8560Eyshno6/cLdvWv0zroclTZprV7/GKl7H35Uj95/p27o00YBXrKPNF33qZKNxf6Vo7ULFmubIywNuOBy3TDqfj368KO6NyZSnYK8pdIc5Rzj773jqU3v3rJHtQf0/WYPI3IlSfu0cauj76LL1LNZZc/BlPl6a80+FZRLcj1PD9+vO2+MUIcgb/t52vC23vqimiBc0uHvPtCynQXyDuqgiBvv1P0PP6pH7xmpyC4t5C1J5QXa+ekqZThH2XbW8Icf1aN3XO44fqndwEf16MOP6tGHh8t+R9BBffPeMmUUSPIKVIe/36A777evc/+oG3T5BQH2DQ9t09KPt9USPgMAcGLsWTJW/Z9PlYe5qnEiFKXq2QFja5iwBY3GVbFaMH9BNctMxY4cqqF9Q1VdZusuVKEXG9sc9mSqoqCBzg+utqxDpVD1dNw9rB3psseT6Urf4WhrH1KHfbgqUJ7zRsN0PRsWrODgmpYYLa9Y3fn69rIJC7+bowGuL16ap/R18Xr2vkh1ahEgS0BPDXxyobb87rKOQ+jDKVr1oPuI48LsLVr9ykQN+1uwAnwsCu43RnGfO0pJOB3UQWfDcsVUOV7DEvZs5TGnV4bT6viEUlbHKtQ1XC/M1pYVcZp4sz1strSJ1JiXViuzppy6kfENDFJQUB0W5z3qVTW9OlbTKgZmK1AjnrxbHYyjFw3OOqv6Ea9+FdvmHXKfqK04S8lvxeqOQVcrPDxC/QcN0rDRsZo6d6lSd9UhXbf6ys/YVq0s7XEMTww65yxjZ6XzQ+o4YVglZ11TSclPRVQtgREeoanrHSuYcWK1NhXnZI/+V58hnllZjnIEbRVS40luqxBHGPlHfn3+eilWVtKbih01SFeHhyui/yANGnaHYp+eraVfZ7hPsOZJm9aq9js8a+WVd7jA8QWutati33pcYY4fhZL9m5X49su6/9b+irjqag0bP1sr0nPdv2w/ku8cNVxjSZV7ExwTFWbpjzrlPpU1oJX1rqKN+wsPV/it7zq+bGFiNdRPowx3D345y1kGoPrlZa391bhlHRQUqKDlFRr78L264eoOat0sUIEt26nXwLF6cGDFVykH9H2ae1mHgk0rlexoatFnrB65a5B6hbRUYLNAte4YoRH/itUNHR3BYr1laFmV91v9YixBoXN7qqejevyB9G3OEgduft6mHxxBdYeeveQ88iMb9cmaffbRsMbz1Kyl2vWM1EiX95rz9af6pkoQbldaUKCAjjco9l8jFdmznVo2C1TgeR0UcctDGnWZY3Rx8bbKb9nkrYBmgQoMrDyP3k0CFdgsUIHNAuyB8KGd+v4Xe1+b6+7XyGt6qV1L+zotQ3pp0F0P6oaL7X8dlu5K07EMDAYA4Pgp1p737lD03MO6/U7PE+DgBGgapbtjDmv2rQS8OA5atXJO1HRy/tR0f1X/LuO0al+BdibNUWzfUDU3jkIu3KLVz4xRz5btNH6d4Q937+Ya8PJOFWSlKeHFaPVoVTU1z163UBMHtFNwlEut4/owhLTNr5upnYezlPbhTEX3blW17vD+JC18ZKDaBUdq4S5j5yksJ1kJFZUvlK+lSxLdQ9lj1dS/clTr4VRNHTRMk95OVUZuiWQNVNCFvRQ1ZIQmPDlDs95foxnXum/eOFXWNa0bE06sdn4H50jnTTtcq8XWLGvJWEWNmKCpb6WehNHKh5X69CANm/yuUnflqkRSYMu26nXNII0YN1UzZsdrzXTHpGL11NTX5YuSCwcrbmWyVi6coQlDeinYXvlAUomy0pbqudGDNPTZ1OPz81QTZw3oOmJiNdRDowx3G5a3Lo26Vm2Mf/RICuhdeWvIwVzX
SDRfaZscX+35XqpB17WxB42uvALVa2h/tWsUZ7SFLncWOEpTmocQPOPbzcqX/f1cXjEKQdLBjWnaXS5JrRV5m+fzZH+vkWonSeUH9PU3+4xrOLRW+MBeCvRwTtr0vtQ5Ojcvrw7fCFcoKKzlj+gA9erTWQG+AQoMLFT+if8/GAAAdmWHlfryCEW/kaHg257T3TWU40TD63DXc7r9nB2aPWKY4jY0+D/p0Ni1DZGzSuz+LNU+TVe60ipG6Z7dwlEf12U074+ZddiHgTPFjNYqm022Oi9LNMB9T5K3v0L7jtPMpJ06WFKgg5lpWvX6E4ru6zqkN1Nzh05WqoeA1r9VDw19eInSsgpkO3xQO5OWaOaDA9TDJWnN+2KMxr9d+S6dwzG6zNTuKsdYw7I91lm/2Mm/lXrcFKsl32WpoKRAB7ev1ZIXYzWgu+sBJGnMfQuP/jybUq5WTHtOmyXp/Lb2LxO+nqpJn9byj5saRkEWVfQFVIy0LVbijFgl5ksKjNDjS5OV8lWiVi6dramPTdCIayPUu13ToxiVWxfBausYLZqbVf0doPrjgBzjeerGWdfUR8NmpyglpfrlHUc935KPlyuxjrfWNwrWXgp3lJXI/WJdHeuz5uqb5B3K37tZiev32O8eCg52ZB57lFnjTvY4yzcE1zTKugbFXzynWMdEgBGT4pX8VYoSP43X7GmTNCEmShG9XCZYq05ZcfV3PZVV9vhWuVB9FRQaoRGPzVbC+hQlr4nXrIm3K8xRTzd31WOaXzGpj1WOLzza6u6lVa+Zqku8Rpxf+UrVcdaA7vyQVlbZh8vy34pJBHdo6Sc1fihAtTzEbidf4MWRuuHGG2pZrlWoSymBurtAbTwVyJYkr8oks7TINUL8WT9XfOPSvpM6VHfWAnqp20XGxmPRWr2qvN/ql2u7VJ1aNqBrN0fQfFCbNhnC1/IMfe/42j2w++Uu76dAGZmO/8k2C1Zrr3zlH6pmKWmplo5ZH/L3/WwPio0C2+ii6maG8HJ+faaCgprjWjct2+gCx/8A9n3+iub9e602Z+ZU1k6WpA7DNWnKJD366L2KPM+lHQCAE6R4f7Jm3NpfscuypOARenzY2crNzT0+S95pdNNemb3+Y5VzcCzLn2frpsdGKFhZSniov6KfT1bWaXQqYeAdopCK3HNdktI8BJ5uNqXIOYiya4gjnAxRSEUJ1MJ07awxdSxU/BCLAoKDFfy3Z7VFIQoPqwguk5RyXAcR+qt52x4acO90LUnKki0nRbEVIXTeQi3ZYFjdKLC5QvtGK/blVUorKNDu+QOcOXTS0lWOcLWnwuzFLqUdSUr1+I+BY+Ttr+Zd+in64ZlataVABf+3QAMqDmBdvFbVeJ5PDbmfPqXn0iSpgya89I7i7rEnoptfeEorahjZt+7b6i6kXG36yh4aBV3WyzHyfLNSHGUKeo2bqMEXekrZMvT9t8a2+rnkUked4eTN2lRNGJ21Y7Pj9ve6cdY19RmiqGrKGFfoEDXY8f6TlfyNmf4n4Kuo4Y4JxLLi9c762r+kLN7whub/YH/c5fr+9vdtvUSXOMaBVX+9SNqconWSpCB1ON/TtVG7zamO0dQ9J2jioLby9VBWJGNHDccgSXu/0abqrvmtmx3HGKbe3exNGe/doUGDrlb406luq/o2baveQ+9W3PuzNMxHkkq0eYcj5Anuoq4+krRH33xbyxcodeVSA7rXoH72yemq0yRC/R3BfdZ/1tQxuAfcVRdTnlQ+wZ3Vq2evWpZOan1MVRC85X+02/2apSxHXdgWZ7c09rq5INh9YrJjE6iQKu+3+qVTsIc3FNBLV3Sxh9X5O7fJLd7dsVHbiiUpUJ17OKY/ley1syoGLB/arMUvvaAXql3e1saKu7qqG03rG1D1tqr68uqggdd3sI8UKC/QgR3J+mTRLM2YNkXTXpqnxauSlZFV21/nAAA0oKxlGj9iklZW3P6ftVT3Xz9IgwYdp2V8Qt1naTa7rATdbXz/
9Vn+tdR57vZ8NknD7l92+pxLGPTQwJsq/lKN19yPqqkz5pC0KM45YnTA4Mpxs+F9K4qiJilusbOSbFX5y5WwQirMzlZ2U3tZh7AbRzv+Vs5W3CJndOzZpslqZ7GoRZtgBY9d7ahssEULh/RUcJsWsvSrYUTr2WEad1vFrQOFKnRkQtmfT1RkaLCCA1po/NeuG7jyV8hd4zS04unhPMff/a008MaKKeZWa47LiF5P8pYOlMUSoOA2wQp/yVF9N3u1JvbrpODgALW43z2IceXffrTG3VTx7KBqnCPuVJCzQk+9YJ9Zqe09UzXiQl+1jZmq24MlabOem7ai2tvrSz6e47H0zOGvZzlCvmANjqo6YVRRNRnn4aR4xVf3YscoKDxKvSSpJEFzPvAQYx1O1ZvvemivVmVdU5/ro1TzlHOS2vVXlKOuiukmVut1pyb2lKQSJT4+Qe/uruaDk6ScRE19cqV9EFbw7Zo4tCJeDFbY1fYTUN31orI9WjrXUV82eLD61/fOo6Jqxt7mJSt+aW2fwA7Nn+ehhILrMV7dXxGOCdvOPrOpcnNLpP8sU2KVjdydFVgx3LeroobaB8zteGuWUqvZbs970Qq/KkqDhsUptYZTL7ca0L0UdVWN0a4kX0X0i7A/ZGI1HKNGGe7i+OjQtbOjTu0P2vZzRWuBNqc5avQEX64r3G4nyFFOzX/XepZ/0HNd3wYS0G2k7h8ZqQ4t3WtGlB46oIwNa7X49Wma8tw8rdxxPIcQAABQR8HDNX9tgh7v5/hjPjBKcWs83IZ3rMv7I5z1Pk95549QvPH912NZ82KU7P98C9ag6QlKfmP46XMuUUXYI3NUEc2uju6nZ3d4Tg2zlw7TsHmOvubjNNkZCkutRsQq2vE0/ZFh1ewjT6vvG+ecEG3oXTH2ydf6jlOs4xb1wnnDNGxpNQFpYbqevfdZZUrK25+t8L79HKFwqAK8tyh7f560bqJiV1Tzh3xpuuLfq5jObKj6XWV/1KqNvzJ3ZSu7ME9zH4yrtp5u3sdLnMcect0AZ0mFVneOd4a+qfcP0OQN1bz+78s17r7VkgqVvT9A/a5zJEWtQuSfma7s7ELlzR6nuOrq6eYt15KPHI/bDtCAGieBahyK8z3cPVDd4nY3RpaWPuooxxB8u6bGON6stYPufnKY/fdX2nN66IPqvpbK0Oy7JmjpZseEUcWHtSdphu54JFH5kgKvnajoimlmdIm62Ges1o5XHtKbm3NV7BhJW5yToRXPR6v/5MSjGkFbJy0Ha8Jt9t+8GXPHasKSTco9IknFOrw7WTNGO0pF1JWzrqmPhkTWGu1KaqshQx3n1XQTqwVp8HNxigqUpAy9eWuEop9fpk27Dzsn4io+vEepb8Vq0PVTlZwv++jvl90n42sbM1GDHPuYfWu0ZiTt0WHHDor3pmr2PdGavUuSAhX1YPRRT25X4ZJQxy0DP7ysh96o+JwlHclVxsoZih48SYl1uMDy/xOrO55N1h7H9sU5m7T0obGOY+ygCfdEOetIB10TrSgfSUrV1NEzlOx6bnI2aelDk5RQIklhGlzxd5qkrqOm2c9rfqJiR8Taf4Yqfh4O71Hy89GKfmOPVJYvv2sGK6zGwcwuNaB7RumK2rJdSb79BstefZiJ1XBsCHdPZR0vV69mkpSvzRsdgW7BTn3vGFTQOrSy7q1dkFpUVHgIGaRJzzyjZ+qyTBnurFV8ogR2iNDI+5/SU4/fq5EDI9TpvBYKcP0Fe+SANv57lpbu8DimGACAhuUbrMHTV2rltAgF5ifqsZnJVUed4MQ6nKy4JxKVHxihqZ8maFK/4MpJhXB6On+05rxYUXl3iyZ3DVH4Q3OVtCNb2dnZSl8Xr8kDghUcvVz22LKVRsfPVJjr+ILAAZr7YbSjBq+nfczVmNAWGrjYEXz2XaC5IyrC4VBN/3ymo/ZvnpZHByv4qomauy5d2dnZyt6zRavnjVGnpp00ueLOZbft/RU96QlH2Jqn+CEt
1Om2OC3flGnfPjtTWz6KU0zXTprsqBcc8vhkRVf8vd9loqYPdjzeNFGd2oRr4rwkpWfbjz17R5LmPhSu0KHxjpHCAzT9IZchfIHRWvhZ5Xt/Nszw+juSFP/kQAW3HKZ4x9tvPmGJpjvn+wjVxBkVo6C3aGJosNu5y85OV9K8iQoPHaZ4R2Y+YMZEuQ0i/DFOnSwWWSwWWbrGqYax0ydU8tMe7hyobnG5GyPrg8cdgVWwbp/uHsjp0ljFDbd/eBmzHtfS/S59Dm2viVKH/M2aPWGQIsLDFR7RX9GTV9r333mC5j8R5jKxZ5CGP3S74wuuDL07YZAirgpXeHi4Iq6/Q899tke6cIQev8vxge3Zc9zudOhw11ua2i/I/u/Uufdr0D/CFR4eof63TtLK/VJQ5w4138buwlnX1GeIoi419noWHDVU9ndlwonVmoZp6odxGuEoNbnns5d1/6397Z93eLgi+kcr9m3H5GlBYXro/fnOdZ2ahGnSkqmKCLKPgF05OVr9Ixzbj4jV0h8kKUhhj87X1CuPfSrYoKETdbtjMFnGexWfc7jC/zFId8xYqT1qqxGT7nZ8Fnu0x+MFFqaoawOVtWqSoh3bR1x/v2Z/ly8pSINenO3+/pqEKfbZQfbrZ/9KTXI9N27bTVWU61trGqapb01QF6uk3FT7z1DFz0P/aE36zFHWpN9UvXVXLemHswa01OuaK+p2LVvDNOhGR+lKJlbDMSDcrYtzgxXsOFMHf6uh8LuknN9P5BjW2rTR5d3t8W3pru+VUS7lb/7WPmGaVztd1sdYQqKFWpzpePjLPvdSDo2Ud5PW6tAnUiPueUiTpjyjpx6/V8P7tHFM8FCgnV+nea4HDADACRB0zQwlvBglvy9maPYGxmGcPMVKfeUpJTYdpLiEGYqqucoWTiOhD6dp5/wB9pG0ylbqK+MV2TVYwcHB6tQvRs9+7hhN22qAZqama8F1VYuONR+8RJmrxzlCVuM+xmuhY0Rq82sWaGfiaMdrOXSMVcr2ORrgaMz+Ok7j+3VScHCwgkN6auB9C50jaltd52H73tOVEj/U2Za+eKKG/a2dffvgdup580TFO16/1Z0JSnvaOY2cpOaKfi9FT/RxPM1OVdx9keoUbD/24K6RGv9Kqr3cg3eYZm5PULRhPg37e49VqCPwdnv9rpGKeWa1s1xE6INrlTmropSDXfMRS5QypaLN/dwFB3dS5H1xSnXsIOzFnUoYUd2EHqeA/Uv1+Cz7gJzA4VM9TsLZ9Z5plSMun6wsM+PU4W7N/3CGRvQMttdmleRzfi+NmJ6g5LdGqK2x5mno3Vr64QyN6NlWgc4+HwV1jNKElxOUvHSCBl/tqNG7a52+OV6BkzVIUdOXK2H6CPW6sHL+mIpjXf5Uf9Vpep0jiVruqGsaNGJw7SUZKrTsp8E97Q9NN7GaJDUP04SlyVq58HHdfmUHBblNweOjoI5hun3SO1rzaZyGt6vma8yWUZqxfKVmPRilDkGV8+EosK16DXlIs1YuV9z1xzpm18HaQXe/n6AZt/RSW5dj9AnqoKhxcUpYG68Jg65Wr2BJytC6rz2VaGiq/k8k6M0Ho9ShYh/WQHW49iG9uXqlJnkIn5teOUkrV7+ph641nBuX9+ZpO104QvPXJihunOGcWAPVtucgPTR7pZZPj1KQ8efIYM8nS+01oBWmwdfUKdqVJPWOGuL4uWViNRw9i81msxkb66uoupoqNfl+saZ8ZP+fWYu/36+HrjmWv7q3afGUZcrwsI9tS6ZomWPY/vBnRsrzl3qV26vjcD0TU7FWgTa/M0Of/CTJt4OGTxypSz2UuVV5hpY9u9hezzYoQvc/GKmjeRd1O8ajlJeseXFrdUDeujTmfgUnv6z//CLpohs06Y5elbPcOuxb9YLmb8iX5K0ONz2ikd2MazgUbNbbz69UVhN/Wdteq9hbLrWXgFCO1r4yS8m5tZyDX9fq5TnJOujhs3Lt63DTMxrpKI4uSTlp
S7X0i5+VX3iOIh+5U5d7PLx8ffPGC/b3WdMxAABOWX5+7tMmH9PfJsfR4a+natgbl+hdE5ZUMJ5LNYLzedT2L1X0qP9pQsJUhXn499yJ4ulcQtKKGFmGxNsfj1wl23uVNW2PReZLndTukXRJoZr5fzsV2964hge/b9HCWZM19+0UbdnvGGbq3Vytuodr9H3TNXFkDzV3rwhWVX66lr84UdMN+wi9apjGTZmu0X1bVT8fRWmetiyeqcmvL1TK1mzlVZRIaN5KPa4arXFPT9To7jUEm87jT1P6/mzHSFvH9oMnavqkcRrQsbpXL1Tm53P17EsLtWpLurKd1RX81apjuGIemamJI3qoVXWbS1JhplbPflZxi1YpZVfl6/u3ClX44NGKfWScBrSvfgeFP67W3BfjtPDzNKVXnLuK7UfEauYjMerh6QB+jFOniycqXZK6zNTu7bHOshH1Ya7fcVlaeuswzd4ttR2XoPgYs/1f5tRn/N1vruvrJPl6qsIfSZQUpbiUqXL/Wgj1ZbwmYX6M3K2TAPXq4wgvizO0avlm5TsmWKtUoG1LlzkmKmtEml+uniGSVKofUj7R979Ikrcu7VM12JWkNldcqdZeklSqjJXLtNnjsNcCbVu+RrvLS1WQn69zLurgCHaPv4O/u3893LKJjw7m56ugdLe++bKaUdRH/qfdFZudGeiorQcAwMnT9MqpSjRhsHvKOH+E4hNPbrCLGgxeIpvNZl/qGexKUsjDOx37q2OwK0ln99Dop1cpbd/BymMpOais71Zp+p11CHYlKTBUQz3sY2fSAo2rKdiVPQTuced0rfouSwdLHNvabLIdzFLaZ9NrDnblevxZKqjYtmL792JrCHZlnzTtulgtSNqprIMu29oKlJW+VjPvrCXYlST/EA14eIHWpru/fkHWTq2dH1tjsCtJ/u0HKHb+Wu10PXcV27882nOwK0ntY7WzYv3jFOwCAGA2hLt11XGgbujouNl/1yeKe22pkncdUP6hfOVkbtQnc17Rsl3Hq75rvjLTNmvz0Sw7DzhmrjUKULdu9qr5pZm7dUCSfDurW0fjeg7Nr9D1VzrGuRZk6JOXZmj+qs3anZOv/EM52r0jWUtfebHyvbaM0KDenmLiejj3bOdI25wdG5WRk6/8QwUqlaSOV+tKx50NB7+epxlvrdRGx+eQf+iAMjas1PzXPlFGsSQFqNNl3dxD7O8Xa8qUKZoyZYpe/sLTfUXbtNjRP+WVtaq6Rm39AAAAAAAAwIlBuFtnAbp0xB269gJ7VFiau1Nrl8zTCy+9oFmLVmrzrwWSd0t1uthYx/ZYHNDmjz/RJ0ezrEmvtrZsQM8rdKlLqZ3AnpfXOAFa62vu0p19WtpH45YXaN+GT/T2rBf0wkuz9Pa/12pnrv0+Me9zL9edYxqi5MEFanee42HORi2e9YJeeOkdbTwkSS0VedsgdWhi7y74eaNWOj6HF16ap8WrNmrfEUkKUJv+d2hEl+McPAMAAAAAAACNBOHu0fBqrSvuekQPxUSqw7kB8q44e74BatkxUiNj79c/GuP9ll4d1K1ilgO11uVXtDGsYBSgdgPv1xOxIxXZpbVaNHG5D807QIHndlDETffrifGD1M4Rsh5fLXTFbXcqIiSw8hzrN2VVzMZw1uUa+fBDGtmvk1o3d/kcJHk3aaHWXSI1MvYRjb2ydWUHAAAAAAAAcIppPBOqAQAAHCfGiSL42+TYGc+lOJ/HzNO5BND48DsOx5Pxdz/XF0424zUJ82PkLgAAAAAAAACYEOEuAAAAAAAAAJgQ4S4AAAAAAAAAmBDhLgAAAAAAAACYEOEuAAAAAAAAAJgQ4S4AAAAAAAAAmBDhLgAAAAAAAACYEOEuAAAAAAAAAJgQ4S4AAAAAAAAAmBDhLgAAAAAAAACYEOEuAAAAAAAAAJgQ4S4AAAAAAAAAmBDhLgAAAAAAAACYEOEuAAAAAAAAAJgQ4S4AAAAAAAAAmBDhLgAAAAAAAACY
EOEuAAAAAAAAAJgQ4S4AAAAAAAAAmBDhLgAAAAAAAACYEOEuAAAAAAAAAJgQ4S4AAAAAAAAAmBDhLgAAAAAAAACYEOEuAAAAAAAAAJgQ4S4AAAAAAAAAmBDhLgAAAAAAAACYEOEuAAAAAAAAAJgQ4S4AAAAAAAAAmBDhLgAAAAAAAACYEOEuAAAAAAAAAJgQ4S4AAAAAAAAAmBDhLgAAAAAAAACYkMVms9mMjfVVVFRkbAIAADhh/Pz83J7zt8mxM55LcT6PmadzCQAAANQHI3cBAAAAAAAAwIQaZOQuAAAAAAAAAKBhMXIXAAAAAAAAAEyIcBcAAAAAAAAATIhwFwAAAAAAAABMiHAXAAAAAAAAAEyIcBcAAAAAAAAATIhwFwAAAAAAAABMiHAXAAAAAAAAAEyIcBcAAAAAAAAATIhwFwAAAAAAAABMiHAXAAAAAAAAAEzIkvHjHpuxEQAAAAAAAADQuDFyFwAAAAAAAABMyGKz2Ri5CwAAjtk72/fokeTvJUkvRnTTHV3bGlcBABxHffv2VXJystatW6eIiAhjt5s33nhDzz//vFJTUxUcHOzWt3v3bl1xxRXKzs52a5ekJk2a6Msvv1SvXr2MXQAAoI7uvPNOvfPOO3r77bd1xx13GLu1adMmxcbG6uuvv1Z5ebnOOOMMjRkzRk8//bSaNWvmtu6hQ4f05JNP6vXXX1dJSYlat26tmTNnMnIXAADUzyPJ3yvnSJFyjhQ5Q14AwMn37bff6vHHH1d143kOHDig33//XX369NFdd93ltowZM0ZnnXWWcRMAAHCcJCQk6PLLL9fWrVs1bdo0LV26VEOHDtVrr72mq666Svv373eu++eff2ro0KF6/fXXdc899+j999/XJZdcoujoaMJdAAAAADjVfPfddxo+fLjy8vKMXU579+5VWVmZnnnmGc2fP99tee2119S2LXdiAADQEP744w89/fTTatu2rbZt26bJkyfrn//8p9577z0tXrxY33//vRYsWOBc/4033lBSUpIWL16sWbNmKSYmRmvWrNEdd9xBuAsAAOrnhau7qmUTP7Vs4qcXru5q7AYAnEAFBQV68cUXdcUVV6ioqKhKKQZXGzduVIsWLXTeeecZuwAAQAM6cOCA/vjjD1155ZVVvkyNjIzUBRdcoOTkZOXn5+vPP//Uxx9/rF69eikqKsq5no+Pj0aPHk24CwAA6ufObiH6bcIQ/TZhiO7sFmLsBgCcQP/973/16KOP6tJLL9V///tfdejQwbiKJKmwsFA//fSTWrVqpZYtWxq76+T2229XRESENm/erIEDB8pqtcpisejvf/+7MjIylJubq3vuuUe+vr6yWCy67LLL9P33leV7bDabvvzyS4WFhclischisejcc8/VY489pr/++svttQAAOJV07dpV+/fv17vvvmvsquLAgQP6v//7P3Xo0EHNmzd367vooosIdwEAAADgVBEQEKAXXnhBX375pUJCqv/CLT8/X3v27FGzZs309NNPq0WLFrJYLDrvvPM0Z84clZSUGDfxKCMjQ4MGDVJJSYnee+89jRkzRl999ZWGDBmiyMhIZWRk6O2339b48eO1efNmRUdH67fffpMkrVixQn379lVhYaEWLFig999/X5deeqleeOEF3XnnnXU+BgAATiXbtm3Tvn37FBISojPOOEN//PGHDh06pM6dO8tisbitGxgYSLgLAAAAAKeKf/zjH3rkkUcUEBBg7HLzyy+/aN++fdqwYYO++uorzZw5U3PnzlWLFi00YcIEjRs3rk7halZWloYPH67PP/9cMTExmjdvnoYNG6b//e9/uvjii7VmzRrFxMRo9uzZeuyxx7Rz5059//33stls+ve//63g4GCtWLFCo0ePVkxMjFatWqWbb75ZO3fu1N69e40vBwDAKe3PP//UjBkzZLFYFB0dLYvFooMHD6q0tFRWq9W4uiwWC+EuAAAAAJxucnNzJUl33XWXvv32W40ePVr3
3XefvvnmG0VGRuqtt95SQkKCcbMqvL29dcsttzj/went7a2//e1vkqR//vOf8vHxca7bq1cv2Ww2FRQUONsOHjyoHTt2yGazSY76gf/+97+1Y8cOtW/f3rkeAACnuoKCAsXGxuqrr77ShAkT9I9//MO4ikeEuwAAoF4WfZ+ps177VGe99qkWfZ9p7AYANEJ9+/bVoUOHNH/+fLcA9swzz9TUqVPl7e2tFStWOEPX6vj6+nocJRwQEKA2bdq4tbm+jsVi0S233KLCwkJdd911CgoK0j333KOkpKQ6jRgGAOBU8tdff2ns2LFauHCh7rrrLr300ktVSjBUh3AXAADUy6PJ25VbWKzcwmI9mrzd2A0AMJm2bdvq3HPP1S+//FLrxGZnn322WrdubWy23ybqVfM/N4cMGaKNGzdqwIABOnTokN58801FRkbqjDPO0EsvvUTICwA4LeTm5mro0KF6//33dffdd+u1115z+0K0RYsW8vb2VllZmdt2ckxOWvP/bQEAAGrh5VX5jbLrYwBA4/bXX395DFCLi4tVVlYmb2/vOo8aOla9e/fWqlWrVFhYqO+++07/+te/5O3trUceeaROZSEAADCz/fv3q3///kpMTNRTTz2l2bNnV7kj5qyzzlKzZs30v//9z61djglSCXcBAEC9vBjRVeee4a9zz/DXixFdjd0AgEbo/vvvV2BgoD7//HNjlzIyMvTrr7/q8ssv1xlnnGHsPi5+++03hYWF6frrr1dhYaF8fHzUu3dvvfrqq/rggw8kiQnVAACntL179+raa6/Vpk2b9MILL2jKlCluI3YrtG7dWhdffLF27typgwcPuvX99NNPhLsAAKB+bu/SVtnjByt7/GDd3qWtsRsA0AgNGDBAFotFs2bN0p9//ulsr5ilOyAgQEOHDnXb5ngKCgpS27Zt9d///lfbtm1z66uY7K158+Zu7QAAnCpKSkr08MMP64cfftALL7yghx9+2Dk5qdGZZ56pG2+8UZs3b1ZiYqKzvaSkRLNmzSLcBQAAAIDTTWRkpMaMGaO1a9cqIiJCCxcu1KxZs9StWzd99dVXevrpp9WzZ0/jZseNt7e3HnzwQfn6+ioqKkrTp0/XsmXLdP/992vs2LHq3LmzbrzxRuNmAACcEr755hslJCTI19dX27Zt0913362xY8e6LS+99JIKCwslSffcc4/69eun6OhoTZgwQUuWLFH//v310UcfEe4CAAAAwOnGx8dH8+bN07///W8VFBRozJgx+te//qXWrVtr/fr1euihhxq83u5ll12mb775Rt27d9dTTz2lm266SYsXL9aDDz6olJQUnXPOOcZNAAA4JWzcuFHl5eUqLi5WfHy83nrrrSrLqlWrVFpaKjlG73744Ye666679Oabb+rWW2/V//73P8XHx8tis9lsxhcAAAAAADROffv2VXJystatW6eIiAhjNwAAaCTuvPNOvfPOO3r77bd1xx13GLuPC0buAgCAelmwLVMtXv1ELV79RAu2ZRq7AQAAAAANhHAXAADUy2Prv1deUYnyikr02Prvjd0AAAAAgAZCuAsAAOrF6lKT0fUxAAAAAKBhEe4CAIB6mdnvUgWf4a/gM/w1s9+lxm4AAAAAQAMh3AUAAPUysvOFOjB+sA6MH6yRnS80dgMAAAAAGgjhLgAAAAAAAACYEOEuAAAAAAAAAJgQ4S4AAAAAAAAAmBDhLgAAqJf5W3frzFc/0ZmvfqL5W3cbuwEAAAAADYRwFwAA1Mvj67frUFGJDhWV6PH1243dAAAAAIAGQrgLAADqxcda+eeE62MAAAAAQMPiX2AAAKBe4vpdqvObBuj8pgGK63epsRsAAAAA0EAIdwEAQL3EdLpA++4bpH33DVJMpwuM3QAAAACABkK4CwAAAAAAAAAmRLgLAAAAAAAAACZEuAsAAAAAAAAAJkS4CwAA6uWNLT8p8OXlCnx5ud7Y8pOxGwAAAADQQAh3AQBAvTzx1Q79VVKm
v0rK9MRXO4zdAAAAAIAGQrgLAADqxder8s8J18cAAAAAgIbFv8AAAEC9vBzZXRc0a6ILmjXRy5Hdjd0AAAAAgAZCuAsAAOplRGgb7b13oPbeO1AjQtsYuwEAAAAADYRwFwAAAAAAAABMiHAXAAAAAAAAAEyIcBcAAMCDadOmadq0acZmnGB8Do0HnwXgjp+JxoPPovHgs2gc+BxOLxabzWYzNgIAANTVvC0/KTZpmyQprt+lurfHRcZVTMlisUiS+FPp5OJzaDz4LBqPvn37Kjk5WevWrVNERISxGycIPxONB59F48Fn0TjwOTQed955p9555x29/fbbuuOOO4zdxwUjdwEAQL088eUOFZSWqaC0TE98ucPYDQAAAABoIIS7AACgXvy9K/+ccH0MAAAAAGhY/AsMAADUy8v9LlVI8zMU0vwMvdzvUmM3AAAAAKCBEO4CAIB6+WfoBdp99wDtvnuA/hl6gbEbAAAAANBACHcBAAAAAAAAwIQIdwEAAAAAAADAhAh3AQBAvZXbbCq32YzNAAAAAIAGRLgLAADqZW7aj2oS97GaxH2suWk/GrsBAAAAAA2EcBcAANTL5C93qKisTEVlZZr85Q5jNwAAAACggRDuAgCAemni4+3xMQAAAACgYVlsNgrkAQCAY/fhrn2atN4+YnfG1V2U/u9FxlVMaerUqW7/xcnB59B48Fk0Hu+884727Nmj22+/XSEhIcZunCD8TDQefBaNB59F48Dn0HgkJycrOTlZERERioiIMHYfF4S7AADguDp48KCxCcBpoGDLu/ILvUFe/mfqzPfPlFeel7KGZ+p722e6LPg24+qop8OHDysgIEDe3twxAQBAY/bXX3/Jx8dHvr6+xq7jgrIMAACgXn4vKNJj67frsfXb9XtBkbEbAAAAANBACHcBAEC93PTJBr2wYZde2LBLN32ywdgNAAAAAGgghLsAAKBe1u/7zeNjAAAAAEDDItwFAAD1EtPpQo+PAQAAAAANi3AXAADUy+JBl+mDIX30wZA+WjzoMmM3AAAAAKCBEO4CAIB6uyW0jW4JbWNsBgAAAAA0IMJdAAAAAAAAADAhwl0AAAAAAAAAMCHCXQAAUC+//lWoieu2aeK6bfr1r0JjNwAAAACggRDuAgCAern50w2K+zZDcd9m6OZPNxi7AQAAAAANhHAXAADUy1f7cjw+BgAAAAA0LMJdAABQLyO7XOjxMQAAAACgYRHuAgCAenl34GX66IYwfXRDmN4deJmxGwAAAADQQAh3AQBAvQ2/5HwNv+R8YzMAAAAAoAER7gIAAAAAAACACRHuAgAAAAAAAIAJEe4CAIB6OZBfoAfXbtWDa7fqQH6BsRsAAAAA0EAIdwEAQL3c8ukGvbrp//Tqpv/TLZ9uMHYDAAAAABoI4S4AAKiXb/b/7vExAAAAAKBhEe4CAIB6ub1rW4+PAQAAAAANi3AXAADUy9sD/qaPh16hj4deobcH/M3YDeB04WV1PixXuf5s+acOKd9tFQAAABxfFpvNZjM2AgAAHKuDBw8amwCcLmzlksU+fiTx0Ofq5ttdrfyDjWvhODh8+LACAgLk7e1t7AIAAI3IX3/9JR8fH/n6+hq7jgtG7gIAAAA4PhzBbgUva+VoXgAAABx/hLsAAAAAAAAAYEKEuwAAoF72Hy7Q/f/dqvv/u1X7DxcYuwEAAAAADYRwFwAA1Mstn6Zq9ub/0+zN/6dbPk01dgMAAAAAGgjhLgAAqJcNB3I9PgYAAAAANCzCXQAAUC93dG3r8TEAAAAAoGFZbDabzdgIAABwNFb8eECSNLh9ax08eNDYDeA0lHjoc3UP6KVzfM4xduE4OHz4sAICAuTt7W3sAgAAjchff/0lHx8f+fr6GruOC0buAgCAehvcvrUGt29tbAYAAAAANCDCXQAAAAAAAAAwIcJdAAAAAAAAADAhwl0AAFAve/88ovsS03RfYpr2/nnE2A0AAAAAaCCEuwAAoF5GrNigeVt+
0rwtP2nEig3GbgAAAABAAyHcBQAA9bLxQK7HxwAAAACAhkW4CwAA6mVUt7YeHwMAAAAAGpbFZrPZjI0AAABHY/VPWZKkARcF6+DBg8ZuAKehxEOfq3tAL53jc46xC8fB4cOHFRAQIG9vb2MXAABoRP766y/5+PjI19fX2HVcMHIXAADU24CLgjXgomBjMwAAAACgARHuAgAAAAAAAIAJEe4CAAAAAAAAgAkR7gIAgHrZnZevu9ds1t1rNmt3Xr6xGwAAAADQQAh3AQBAvUSv2Kj5W3dr/tbdil6x0dgNAAAAAGgghLsAAKBevsvK9fgYAAAAANCwCHcBAEC9jOnWzuNjAAAAAEDDsthsNpuxEQAA4GgkZv4qSYoKOVcHDx40dgM4DSUe+lzdA3rpHJ9zjF04Dg4fPqyAgAB5e3sbuwAAQCPy119/ycfHR76+vsau44KRuwAAoN6iQs5VVMi5xmYAAAAAQAMi3AUAAAAAAAAAEyLcBQAAAAAAAAATItwFAAD18uPBfI1e/Z1Gr/5OPx7MN3YDAAAAABoI4S4AAKiX6BUbtGj7Hi3avuf/2bvv8Ciqr4Hj393spjdCDb2EJh2kdwHpRREEf1SpigFREGlSBJRmAfUFESkivXekSe9NeidACklI75vsvH8kGbKTDSSEoOj5+Mzjcu+d2Z2Zu5vdM3fO5b0tx7XVQgghhBBCCCFyiAR3hRBCCJEtZwPCrD4WQgghhBBCCJGzJLgrhBBCiGwZUKWE1cdCCCGEEEIIIXKWTlEURVsohBBCCJEVe30CAWhWLB+hoaHaaiHEf9AfETuo6lCDfMZ82irxAkRGRuLg4IDBYNBWCSGEEOIfJDo6GqPRiK2trbbqhZCRu0IIIYTItmbF8tGsmARwhBBCCCGEEOJlkuCuEEIIIYQQQgghhBBCvIIkuCuEEEIIIYQQQgghhBCvIAnuCiGEECJbrodE0mfbKfpsO8X1kEhttRBCCCGEEEKIHCLBXSGEEEJkS48tJ1ly6R5LLt2jx5aT2mohhBBCCCGEEDlEgrtCCCGEyJZzj0KtPhZCCCGEEEIIkbMkuCuEEEKIbBlcrZTVx0KI/zazYtYWCSGEEEKIF0ynKIqiLRRCCCGEyIoDD4IAaFwkLyaTSVsthPiPSVKSCIx/hKJAQYeC2mrxAgQGBuLq6oq9vb22SgghhBD/ICEhIdjZ2eHk5KSteiEkuCuEEEIIIYQQrxh/f3/c3d1xcHDQVgkhhBDiHyQoKAh7e3tcXFy0VS+EpGUQQgghhBBCCCGEEEKIV5AEd4UQQgghhBBCCCGEEOIVJMFdIYQQQmTLleAIem49Sc+tJ7kSHKGtFkIIIYQQQgiRQyS4K4QQQohs6bn1JMsu+7Dssg89t57UVgshhBBCCCGEyCES3BVCCCFEtvwVFG71sRBCCCGEEEKInCXBXSGEEEJkywfVSlp9LIQQQgghhBAiZ+kURVG0hUIIIYQQWXH4YTAADQrn0VYJIV5RCaYobI3O2mISEiKxtXXRFouXzN/fH3d3dxwcHLRVQgghxD9KYmIsBsN/9+9VUFAQ9vb2uLjkzPcnCe4KIYQQQggh0nkYcJR7D3Zri7GzdSM+QVKwCCGEECJzKpbtibvrf/cOv5wO7kpaBiGEEEIIIUQ6RoOjtggAGxtbbZEQQgghhPibSHBXCCGEEEIIIYQQQgghXkES3BVCCCFEtlwKjqD75uN033ycS8ER2mohhBBCCCGEEDlEgrtCCCGEyJaeW06w8uoDVl59QM8tJ7TVQgghhBBCCCFyiAR3hRBCCJEtl4KfTKyU9rEQQgghhBBCiJylUxRF0RYKIYQQQmTWx3vP8/3pmwAMe7003zWrqm0ihHgFPQo+z827m7TFODrkJSY2SFss/kN+X3qYeT/s0RZnKH8BN+b/2h+AQe//gmfBXMz45j0cHF/dyfkeB0cy6P1fAJj/a39y58mZGdCzIzYmgc8+Wc61K7788HNfypYrCGnO3+CP
mvO/Xg20q6WT2v7r2d2p37CstvqZFEXh2JGbBAdF0uGtGtrqf63MvE9cXByoVacUffo3pniJvNrqHDV14gb+3HfFom8IkVMqlu2Ju2tJbfF/RlBQEPb29ri45MzfChm5K4QQQohs+a5ZVY72eIOjPd6QwK4QQvwHODnbky+/m8Xi4uIAgF6vI29eV4u6/AXc0NvIT8//qhPHbjHqk+VERsZqq/4TCni6075TjXTLG80roKCwd/clPui/kKuXfbWrCiFEprxSf2FjY2OZMWMGe/ZkfPXrxo0btGvXDltbW3Q6HTqdDi8vL65cuaJtmiWKojBjxgx1m2XKlMHHx0fbzCqTycTAgQPVdWfMmEHqgOkzZ87g5OSETqejd+/e2lWz5dixY9jZ2anPu3LlSm2TDPXu3Vtd71mLh4cHderUYeHChUREWJ9IJ3V7Tk5OnDlzxqIu7THIzGJjY0OJEiUYPHgwV65cUY+lNYqisG/fPiZNmqStytCDBw+YMGEClStXxsbGxmI/W7VqxZYtWzCZTNrVXhmPHz/m+++/p06dOhbvE2dnZ+rUqcOyZcuIjo7Wrqby9/enWLFi6Kz02Zzszy9aREQE48aN48KFC9qqdH7//Xd0Oh1t27YlLi7Ook5RFK5cuULfvn3x8PBQj2ehQoUYNWoUjx8/tmifKivvsbRLZo+ryWSib9++6HQ6tmzZoq226vHjx8yYMQMvLy+L56xYsSKzZ8/O8P2dlr+/P59++imFChVS1/fw8KBv375cv35d29wqa9vInz8/gwcPzvQ2oqOjWbx4cbr3cfHixRk3bhz+/v7aVdKJjY1l6dKlFtuwsbGhcuXKLF26lNjYZ/9ASUpKYtu2bTRo0MBiGw0aNGDbtm0kJSVpV8mU8PBwmjdvjk6no0mTJkRFRWmbqE6dOoWjo6PFOc1oKVasmHb1TKtbKDd1C+XWFgshhPgX6vT266zbMtxiGTuxEwCVqxbj9zUfWdT9+PP75MrlpN2MeEX8r1cDDp2c+FyjdgGSkszaov+UqtWK8dmY9umWSdO6sHH7p7TrUJ2oyDh+mb+f+PiX9ztz7MS32H1wrIzaFeJf4JUJ7l69epXq1aszatSoDH9Qnzt3jtq1a7Nt2zaL4JudnR358uWzaJtVOp0Ob29v2rVrB8DNmzeZOHHiM4N8iqLw7bffsmDBAgC6dOnC8OHD0el02qYvlKIo/P777yQkJKhlixcvzvDYZUdoaCgnTpygf//+FClShB07dmibvFBms5l79+4xf/58KlSowMiRI62eh+joaN5//32aNWvGnTt3tNXp3LhxgzfffJOiRYsyefJkLl68iNn85ItIaGgou3btokOHDpQvX559+/Y9NbD8TxMQEEDfvn3JkycPH3/8MSdOnLA4btHR0Zw4cYKePXtSsGBBVqxY8dyBp3+6ffv2UapUKb799lsSExO11RYURVEvKDVq1Ah7e3u1LjY2lk8//ZQKFSqwePFiQkND1To/Pz9mzJhBuXLl2Ldvn1r+MqR+7ixevFhblaFt27ZRtGhRRo0axe3bty3qLl++zIgRI576/lYUhaVLl1KsWDG++eYb/Pz81LrQ0FAWL15MuXLlGDVqlNX3K8/YRmBgIPPnz6dcuXJMmTIlw20AnD9/nvLly9O3b99072MfHx+mTp1KsWLF+OmnnzJ8Dz98+JBGjRrRu3dvi22YzWYuXrxI7969qV69OlevXtWuqgoPD6dbt260a9eOI0eOWGzjyJEjtGvXjnbt2hEenrUctYqiMGvWLPbu3autsurGjRs58tkvhBBCCCGyx87OSLcedXF1deDaFV8C/MO0TYQQ4plemeDu5s2buXbtmrbYwqJFiwgLS/4w7Nq1Kzdu3MDPz4/Nmzfj7u6ubZ5lDg4OzJ07l5Ilk/OELF68mFWrVmmbWTh16hRfffUVAKVLl2bmzJkYjUZtsxcuMDBQ/eFfvHhxAPbv38/58+c1LZ+tRYsWDBgwIMOlVq1a
6PXJXSkiIoL33nuPkydPajeTKe+88w5+fn4ZLg8ePGDz5s20aNFCXWf27NnMmTPHYjsA165dY/Xq1dpiq7Zt20a1atXYvXs3AKVKlWLq1KlcvHgRPz8/7t27x2+//UbNmjUBuH37Ni1atODHH3/MMDj0T3L+/Hlq1aqlBvvy5s3LZ599xokTJyyOa6tWrSDNeXxaIO5VtmTJEoKDg7XFVj1+/JiTJ09iMBioX7++Wm4ymRg2bBjffvstAEWLFmXRokU8ePCA8+fP07NnTwCCg4N59913uXz5srouKYFi7XvJ2vLee++peXn0ej0dOnSw2I5WasB51KhR2qoM7d69mw4dOhATEwNA//791b5/5MgRdV+e9v5ev349ffv2xWQyYTQa+eKLL7hx4wYPHjxgxYoVlCpVCoAZM2bw7bffWn3fpN0GQIcOHThy5Ah+fn7s2LFDff+NHz+eIUOGWO2bly9fpkWLFjx48AA027h48SLe3t4YjUZMJhPe3t6sX79euwnCw8Pp06cPp0+fBqBmzZrs2LEDPz8/Dh48qH7+XLt2jbfffpuHDx9qtpDcP0aOHMnatWshZRsHDx5Mty87d+5kwIABVvclI1u2bGHatGna4gyl3rlia2vLe++9l66PpV26d++uXV0IIYTIEYGB4Uyfupmm9b6kYa2JvNV2NutWnyQx8cnggsfBkbzT4VumTNzA3t2XaNX0KxrWmsjwj5YSEpJ814rZbObYkRsMfv8XGtaaSMNaE+nS8Tv27LK8wJsqPt7EutUn6Pb2HLV903pf8tnw5dy7mz6XtKIonD19l17dfqJhrYk0rjOJOd/sJDb2ySCaVKmvd+rEDTy4/5gRw5apzzH4/V+4cM7H6negoMAIpk/dTIvG0575ekizD106fqduv1e3n9i357LVfc6u35cepmGtiRw59OQOKrPZzL49l9Xj0rDWRFo0nsa0yRvxfRgCKTl/vQcv5vNPVwAw74c9NKw1kd+XHla3Ex9vYuf2C/yvyw/qdt7vMY+zp++mO1a/Lz1Mi0ZTuXrZlz27Lqr737jOJD4b/jsP7qe/Yy4r/eP6NT9aNJrKkl8PsuK3IzSt9yWN60xi2uSNxFk53y+Ks7M9Do62JCQkEhdn+Z3wcXAkP3y3i9bNvlb7xsRxawkKtH5HXXR0PP83d7favnWzr1nx2xEuXrhPi0ZTmTpxg9p26sQNtGg0levXngyoIIv9a+rEDbzT4Vv8fENZ8dsROrScqb7O6VM3E/I447vLhBAvzisT3H2W6Oho9QdsoUKFmD17NqVLl8bT05NSpUphMBi0qzyX4sWLM2PGDDWY6e3tzblz57TNIGWEWI8ePQgLC8PW1paFCxdm65bXrNi/fz/Xrl3DYDDg7e2No6MjCQkJrFiR/Ic1K7y9vfn5558zXE6cOIGPjw8NGzYEICwsjC+//NJipNiSJUtQFIXo6Ghq1Mg4ib6joyOenp4ZLoULF6Z9+/bs2rWL6dOnq+vNmzdPDeZk1bp169TAltFoZOHChVy/fp0xY8ZQsWJFPD09KVasGD169ODEiROsXr0ao9GI2Wxm2LBhT00T8k9w8uRJmjZtqh6fSZMm4ePjw/Tp06lVq5bFcd2xYwdHjx4lT548kBI4X7hwoWaLGatRowbR0dEoisKSJUu01a+k8+fPc/XqVcqUKUO5cuXU8i1btvDLL8mTaDRt2pRz587Rp08fChcuTJUqVViyZInaR4ODg/m///s/iy+o/fr1S/de0i7z58+nYsWKREZGAvDVV1/x9ttvq9vQunLlCo0bN1YDzpkRHh7O+PHjMZvN6PV61q5dy4IFC9S+X69ePZYsWcLKlSvR6/WEhYUxdOhQi9GmgYGBTJgwAbPZjLu7O4cPH2bSpEmULl2awoUL061bN06ePEmDBskTdnz11Vfpgt23bt3C29tbfR2LFi1i48aN1KtXD09PT1q1asWRI0eYMGECAAsXLkw3ijgxMZEpU6aogfvp06db
bKNixYrMmTOHP//8E3d3d/U9fPfuXYvt/Pbbb+rFsQEDBnDkyBFatWqFp6cnDRs2ZNu2bQwYMABSAryp/SCttP0jdRsNGzZU9+XAgQP069cPgLVr12Y6dYaPjw8jRoxI98U6I3FxcepFvcqVK/PDDz+k62dpl6+//lq7iUy5EBhG103H6LrpGBcCZcSJEEKIp7t2xZe+Pebz574rNGpSjgaNyhLyOIrvZm3npzm70wX1Duy7wpQJGyhd1pM3mlegYKFcuLk5kpiYxP/N3cNnw5dz9YovDRqVpVWbKsTFmZg0fh2zvt5mESyOjopj1Ccr+G5W8neIth2q0bZDNQoWysWxIzf4oP9Cbl5/krpJURRWLDvKsA+X4HMviAaNytLkjdfYvOEMI4b9TnhY8oVxrfv3H/PZ8OWcP+vDG80rULd+GS5feshHgxaxavkxi/07ffIOvbr/xNZNZylSxIP2nWpQsXIRjh25Qe/uP3Fgn2V6wZDHUXzqvYzvZu0gLs5EqzZVeKN5Bfx8Q5kwZk26fc4JiqKw8vdjTBizhtDQaFq1qULbDtXIl8+VHVvPM+j9X7h7JxAbg5569UtTs3byRf6y5QrSvlMNSnnlh5TzMWXiBqZO3MCjgHDeaF6BN5pX4L7PY4Z9uIQVy46m6wuKkhwknvzFegp4utO2QzU8C+bi2JGbfNB/IXfvBKpts9o/Uq36/RjzftxLnXpe1GtQhiJFcmPvYKsGudMGSEkT/NaWZ9atm48IfBRBAU93ChR4Mijt5o0ABvb9hVXLj+Hm5kjbDtWoWLkIe/+4RK/uP6XL0RvgH8YH/Ray/LcjantPT3d+mrub6VO3YDKl31et5+lf0dHxTJu8kXk/7qFCpcK0alMFZxd7tm46y8dDlhIamnHKPyHEi/GvCe4qiqLeXl2iRAnc3Ny0TV6Yjh07MnDgQEgJZI4cOTLdbbUmk4mJEydy82by7OFffvmlGtjIaXFxcfz2228AlClThnfeeUcNvO7evZtHjx5p1si+woUL89NPP6kjpI8ePcq9e/e0zV6Y1DQZLVu2hJTA0KVLl7TNnunhw4eMGzdODSitWLGC999/HxsbG21TSHneLl26qCOFzWYzY8eOzTCv6t8tKiqKcePGqSPap0+fzvjx43FwSJ7wwpq6deuyfPly9QLG5MmTuXXrlrbZf8b+/ftRFIVmzZqpQe/w8HA1d7anpyc///wzHh4eFuvpdDoGDRpE7dq1Adi6datFmoHMWL9+PePGjYNnpHTx9/fngw8+oEKFCpw6dUpb/VSnTp1SR+J+8MEHVoPHOp2Ozp0707VrV0jJrXzx4kW1/vjx42qwdtiwYdSqVUutS+Xh4cHgwYMh5XPzjz/+sKhft26dmgfX29tbzUmcltFoxNvbm5o1a2I2m5k2bZrFZ+/NmzfV7bZs2RJvb+902wCoV68ew4YNA8DX15ejR4+qdYmJifz5558AuLu7M3To0HR3WxiNRkaPHk2hQoUA2Lt3rxqAJ2X09M8//6z2j88++yzdNhwcHJgwYQJeXl4oisLPP//8zNQJqaOBb968SalSpXB0dNQ2SSc8PFy966VixYov5C4Wa3pvO8Waaw9Zc+0hvbdlrQ8KIYT474mLM9GocTk2bP2ESdO68NWs7vzfL/2wtzfy594r6W5Nj4sz8f7AJsz5v95MmtaFkaPbY2Oj58ihG6xafoyChXKxdOWHfDWrO2MnvsXKdd7UqFmSLRvPsHXzk4E4O7f/xZlTd2jfqQbLVg/h83Ed+XxcR5atHsLAD5sRFRnH/jTB1BvX/Vm04E/yF3Bj8fIP+GpWdyZN68L6bZ+QP79bulGWqa5cekju3M6s3jiMSdO6MOPb9/i/X/rh7GLPsiWH1RG5QYERfDdrOzHR8Uyc+g4LfxvEZ2PaM3deH779oReOTnbM/HqrGrBMSjKzaMGfXDjvQ7sO1Vm7+WPGTnyLSdO6sHrjMKpULZZun3OCv18Yq5Yfo/xrhVi2
eghjJ76lHsfe7zfCnKRw+dJDbG0NdO9Zn85dk78bNm3+Gp+NaU+deqUB2Lj+NH/uvUKVqsXUYzVpWheWrR5C4SK5mf/jHk6ftEyvFx9v4sYNf35a8D5z5/VRn7d1u6qEh8Vw9PANtW1W+0eqqKg4vvjybb6a1Z2vZnWnZ9/k39IvWmxMAvv2XObLCetRFIW2Harh5p78/S42JoE53+wk8FE4fQc0Ufvr3Hl9mPJ1V2Ki4/nqy01q4DQpycyShQe5eyfQon8v/G0QE6e+w4P7wc/Mffy8/SsqMo7HwVH8vsbb4hhXqVqMu3cCuXwx/V1uQogX618T3E0rddKanGIwGJgyZYp6W+3evXstcjcqmnyXTwvK5IRbt26pwYpatWpRpEgRNVfwtWvX2L9/v2aNF8PLy4t69eoBEBISkuMBQQcHBxo1aqT++6+//rKoz4xffvlFDXx4e3tbDWxZ0717dzVod+rUKY4cOaJt8o+wdetWNdVEu3btMgx2aTVt2lQN5Pn7+7Nu3Tptk/+EqKgo9dw2b95cLb9y5Yo6Yn/gwIF4eXmpdWm5ubnRrFkz9Ho9zs7OhIQk36KWGT4+PowePRqz2YynpyfTpk1LFyAk5TV2796defPmqWXNmjWz+PfTnDx5EkVRMBgMdOvWLcP+YTAY1M+RxMREi/zC8fHxNGvWjAIFClgcJ61y5cqpAcm0aQgSExPVoLSjoyM9e/bM8HXkzp2bTp2SJ2y5ePGixefM1atX1WPcs2fPp17EaN68uXpHR9q0CnFxcerIX1dXV3Lntj5BWN68edXzfv/+fYsJze7du6fuT7169dTUOFpFihShdevWAJw4cSJdruO0Uv+urFmzBnd3d2bNmqVebHia+/fvq0HzOnXqZHhcs+tK8JNbA9M+FkIIIaxxdrGn1/uNsHewVcu8yhSgWo3iREbGEhFhecHT1dWBRk3KWfwdi483sXnDGRRFYeAHzShWPK9a5+Rsz9BPWuHsYs+OreeJjoojLs7E5YsP8MjtTMe3a2AwPBnModPpKFkqeY6WoEdP/o4d2H+VuDgT3XvUp0TJJ3O4uLo6MODDNzAarQ8IMRptGOzdHI/czmpZxcpF6NKtjkUA8ujhG/jcC6ZNu2o0eeM1i/17vVZJtf3uXckX1R8+eMy+PZcpXMSDvgMaY2f35LuhR25nBns3x2i0Ydums0RGPv2icXZERsYSEx2Pi6sD9vZPXoNOp6P/4DfYvncU7TpUt1hHKyQkih1bz1s9VgU83fl4RGsUBbZtOZcuKNmwcTkqVCqs/ttgsKFx0/IA3LuTHDjPav9Iq1jxPLxeMzkVY1qpE8uNnfhWpspT7dx+QU1vkHZ5s8k0JoxZQ1ysiY9HtKFr9zrqOn9duM+Fcz68VrEw73avY9FfGzUtT6fONbl7J5Czp5LvQHsUEM7RIzcoVjwP/Qc3VdvrdDqavPEabdpVU9fPSHb6V7uO1Shc5MlgFydne+o2SA7i+9yznl5ECPHi/OODu9OnT0en0/H555+rZR06dECXMqv3rl27cHJywsXFhQMHDgBw4MABXFxc1DaZmRU9q3Lnzs0PP/ygjoKaOHEihw8n5w46fPgw48ePh5ecZzfV2rVrCQsLQ6fT8e6776LT6WjWrJk6sjCnJlazs7PLcFRY6ig8Jycnzpw5o61+oc6cOYOTkxOvv/66mkN06dKl6FJmg0+9/Tk4OJg1a9ZAJgJKWm5ubvTq1YtixYoxdOhQChQooNalPn/qc8XGxjJ37lyKFy+uvob8+fPz6aefZtg3/f39KVasGDqdjt69e2urLaQeW21fTzuCm5Rbw58W7ErLYDDQu3dvihQpok4clZnJ1dLu+9Ned0REBLNnz6Zy5crqxRgbGxtq167N8uXLM+yf2mOblJTEtm3baNCggbodW1tbWrVqxcGDB9PdxpV6rJYuXQpATEwMr7/+OjqdjiZNmlgE6AAePHjAxYsXKVCgABUrVlTLjx07RkJCAgaD
4anBTICpU6eSlJTEpUuXqFSpkrbaqsTERL7++mt15P8XX3yRYQA5LVdXV5YvX86uXbsoWDBzs94GBwdTqFAh3N3dn9k/tMcnVZcuXdizZw/+/v5PvUPh2rVr6nsy7Wdi2oBqnjx5nvnaU49jTEyMRS72wMBAihQpgtFoxNXVNc0a6cXFxVnt0waDASen5Nm8ExISMuyLcXFxBAUlf1HNly+fxUR7t27dUoPMLVq0eGpaoCZNmkDKaOYLFy5oq1Vp87d/9dVX6sWlZ0mdTM1gMFChQgVt9QvzUY0n/TPtYyGEEMIaV1cH3HNZ3oFia2vAzc0RkymJ+DjLCW89cjvj7p789znV4+Aobt0MIG9eVypXLWpRB1DA042SpfJx704Q/v5h2Nsb+eLLzmzaMYLSZQoQGhrNuTP32LD2FKNHrmTiOMvBDHGxCVy55IudnZHXKibfrZNWkSK5KVTY8s6tVOUrFFKDxWm9XrMkNjZ6rl7xIynJzLUryXd11alfGhub9D/N69YvjZ2dkSuXfImLTcDfL4yIiFjKv1aIvPnSf9cpWSof5SsU4r7PYx4FZG3C1qwo4OlOocIenDx+iwG9F7Bp/WkC/MPSffd+mof3Q3j4ICTDY1WyVD7y5HHh6mVfwsMt01+U8sqf7ndbnrwuFoHmrPaPtDwLuuPg+OTCQ3YV8HSnfaca6lK6TPJvR1dXB6Z83ZWd+z+nc9cn89gAXDifnJ+5QcOyODk/+Z5JSsC2+uslICWtB8B9n2BCHkdR/rVCeHg8CZQD2NjoqVM/OdD6NNnpX6mpNtIqXuJJQF0IkbPUT4/Lly+TN29eNfiU0VK6dOnnzm36b1OrVi31lumEhAQ+/fRTLly4wMCBA0lISHjpeXbRBCzLly/P66+/DilB5jfffBOyMbHas4SFhXHjRvJVaAcHB4uAZ06IjY3l4MGD6r8rV65sUf8s165dU19vzZo1KVu2rLbJU3344Yfcu3eP77//3upt6KQck7fffpuhQ4fi4+OjlgcGBvLNN9/g5eXFtm3bLNZ5Ufz8/Dh79iwAxYoVe2quY2tatWrF/fv3Wbx4MS1atMgwVUVW7du3j1KlSjFixAguXnwykYHZbObkyZP873//o3r16ly9elW7qoWwsDC6detGu3btOHLkiLodk8nErl27aNy4MR9++GGWJqrSOnjwICEhIVSvXl0NOCqKol6gKFq0KKVLP/uLUlYdO3aMX3/9FYDGjRvTo0cPbRML+fPnZ86cOQQEBNC9e/csnatvvvmGhw8fEhQU9NQ+kjZdgU6ne2YgWCskJEQdTezu7q5+HmVX2vfV4MGDuX//PgkJCbRv396indbhw4fVHyBpL0rZ29uro2kDAgJYtWqV1R8qO3fuVPO8169f32IbqeWk9JGnKVKkiHos0+5LWo8fP+ajjz4iLCyMPn36qLl6M+PEiRMA5MqVi7Nnz9KqVSs8PDzUv+n58+dnyJAh6fIOZ9U3b1ThZK9mnOzVjG/eqKKtFkIIISzky++Gg7314FlSkjndqED3XE7Y2lpeLE0dPRobl8CC/9vHjGlbLJYfvvuDwEcRxMTEq4GoxMQkVv5+lDcbf0WHljMZ+sFivpmxjeNHbuKecjt8KkVJfi1Ozna4uaVPhWRra8A9l2XAOVX+/G44pBmVnMrO3oDRaEN4WEzyreyPI7GzM5Ivf/pAGoCbmyNOznZER8VhSkxS96NEyXzpgpsAdnZGcnk4Ex9vIiY65yYAc3Nz5PPxHcmT14W7dwKZ9fVWunT8jjbNpzPr663c93n2xMWRkbEkJZkJ8A9n7re70p2/BfP2kWBKJORxVLrJw9KOEM3I8/SPVG5ujun6W3ZUrVaMz8a0V5eFvw1izBediIqK4+upm7l5PUC7ijqC/OSJ2+le+4xpW9i1/QI6nY6HD0OIjUnA517yMS9e0npANV9+V4uRuNY8b/+ytzeSy8P6e0EI8XKowd3XXntNzYf4NJ9++ilFihTRFueY
IUOG4OfnpwZRSZmcy8/Pj1OnTtG4cWPu3LnDzZs3qVMn+TaGOnXqcPPmTbVNvnzprwS+KEOHDqVLly6QMrLqjTfeUEeSvcw8u6lOnz6tBsW6dOmi3rZrMBjUWdATEhL4/fffrQYsnpeiKGzatEmdXb5EiRKUKJF8NTGnbN++XU05UK5cOapXT771p3Llyty5c4edO3eqQZN33nkHPz8//Pz81JnuL1++rOZp9vLywtnZ8grni/Dxxx+zc+dOChQowKJFi9SRoP3794eUkYcdOnRQ9+NFunPnjppf2cvLK8NR1S/Tli1baNWqlTpKs2fPnhw5cgQ/Pz+OHDlCz549ISXw3rFjxwyDXaQc27Vr11K1alVWr17NgwcPuHHjBsOHD1dHhc6fP99ioqrvv/8ePz8/3nnnHUi5CLFz5078/PxYt26dOmKTlGBm6nnp0KGDOjIzOjoaX9/kyQuKFCmCk5MTERERzJgxAy8vL4ug2dNGZ2ckNjaWGTNmkJCQgF6vZ+zYsU/tm87OzqxatQpvb+8sB1yz4syZM2zevBlSLhxVrVpV28Sqx48fs3jxYqpVq6be3TB69GiLUaSZHS2bKnVEM8D1609mbc6se/fuqaPaPTw8LNK7kNIvmzVrBsCYMWMYPXq0+l4KDw/n66+/5v3334eUC2fatDupr8/R0fGZf3/0er26rrV9MZlMjB49mlOnTlG6dGkmTpyY6TtB0k40GhQUxLBhw9i1a5dFSo3AwEB++uknypYta5Fe6HnU9PSgpuezf2wJIYQQL1JUZBw7tp1ny8Yz6Za0uXsVReGnObv58fs/cM/lyMAPmvHzogFs3f0Z+46M4+MRyRd3XwQbmyd/318EWzuD1ZG9GTEYbDDaZv5i//MoV74gazZ9zE8L3qd126o4u9gTFRnHpvWn6fnuj+kmgstI4KNwtm46m+7c7dh6PsMJ67Iis/3jZdLpdLRqW4WefRoSFRnH5C/WZ/hazp+9l+51b9l4hkMHrll8b9NOcpaTXkb/EkJkjfoXQqfTPTV3JCmjG1MDmS+Ls7Mznp6eFsGNXLly4enpqd4Kmz9/fgoUKICdnR2kpAcoUKCA2iYro9iyymg0MnPmTHX0XuqtuC87zy4pwaglS5agKAq2trbqZGOpateuTbly5QDYuHHjC5nwLC4ujkuXLjFw4ED69u2rlnt7e2cqH2RWpT7fgAEDeOedd9SJ0KZMmUL+/Mm3ghiNRvLnz0+ePHnU4+/o6Iinpyeenp5qkC7tJGg5MfqSlP7QsGFDzp07R58+fShcuDAVK1bk559/ZuXKlej1esxmM+PHj083KV92RUREqH/wCxUqZBG4/DsEBgYyevRoTCYTer2elStXsmTJEurVq4enpyf16tVjyZIl6nG5efMmI0eOzHDkbUhICP369ePo0aN06dKFwoULU7p0ab755ht+++03dDodiqKwfPlyNYjv7u6Op6enmvdVp9ORJ08ePD09yZ07t8X79dGjR5w+fRpHR0d1BDya2/Hz5MnDX3/9RcWKFRk1apRFztTnHZ194MABtX2bNm1e+gUia8LDwxk7dqyaUiEz7+/Dhw9jNBrJkycPffv25f79++TNm5eVK1cycuRIi2Ntb2+v5jAPCAhIN9laWlFRUWqQ+XmYTCamTZvGnTvJt7D17NlT/VxM5ebmxvr16/H29kav1zN9+nQKFCiATqfD3d1d7cddunTh4MGD6e7OyKjPWpM/f35y5cqlLVZt3ryZhQsXotfrmTVrVrrnepqIiAiLQLirqytffvklFy9eTHdBxWQyMWTIEGbOnJlmC0IIIcQ/l4uLA45OdlSuWpRdf47m0MmJGS71G5bF3y+Mvbsv4Z7Lie9+7EXPvg0pX6EQbm6OVn+z6XTJQdroqPh0aQEAFBT1zjGtgIBwYmPSj5yNjIgjPj4RN3dHnF3syZ3bhfh4E4Fp8vymFRwUSXhYDEajAb1OR/4CyZOG370TaPWCbExMPI/8w7Cx
0WeYD/hFMhhsqFSlKGMmdGLH3s9Zt2U47TvVwGxWWL7saLpctmm5uDhgY6OnRctKHDwxId05S112HxxL2XJPT9llTVb7x8um0+no0r0O5V8rxMMHj/lpzm6LAG3elNHcU6a/m+71pl3mzuuDg6OtmgIhNeewVuCjCOLjn/4d9Z/Wv4QQmWdx+a9IkSJ8+umnaYtUupS8txlNLvNfVqxYMUaOHKn+28bGhkGDBmV6dNWLknam+KZNm6YbWZc/f351wjBfX99MB5xScxxbWxwcHKhUqRK//PKL2n7AgAFZum04rbS5ca0t2uczGo0sWrQo0xOhpZV2lFzhwk8S8r9Ijo6OfPvtt+lSVOh0Orp27coHH3wAKZNapU7A9KKkDerk1P5lxYYNG7h8+TKkjIbs2rUrOs0XaZ1OR+fOndULBdu2bctworxChQoxduxYq6NVmzVrRvnyyZMqBAcHExeX8RfLjFy8eJH79+/z2muvUbLkkwkVTCaTmnv29OnTtG3blgcPHtChQwero5BTR2dnJiAZGxvLjz/+iKIo6PV6hg4danX/Xqbo6Gg++ugj9u7dCykT8z0tn3KqwMBANaieKigoiNmzZ6sTPqb1zjvvqKPLR48ebbVNUlISc+fOVfOrZ5XJZGLq1KksWLAAgAoVKjBmzJh0/ZCUwGh0dPLsw9akjrjV7mNW2djYZHgB8tatW3h7e2M2mxkzZswzU01oRUZGkjt3bpycnHj77be5e/cu48aNo2LFiuoFlaVLl7J161b1gocEd4UQQrwqcuVyomixPNy5HUiA/7MHSaTepl+8RF5y53axqFMUhZPHLSc2tXewpXLVosTHmzh25Mn36lQB/uHcuR2oLQbg1s2AdHlcSZkkS1EUatZKzr1b7rXkoOXxIzfTTRoGcPrUHZKSzBQvkRd7B1s8C7rj6urA1Su+6VIVAPjcC+b2rUfky+9K3rzWUz28CHt2XaR9y5n8Mm+fRXm+/G707d+YvHldeRwcSVxcxsFEz0Lu5M3nyo3r/oSGZvyd63lltX/8HdzcHBnwwRvo9Tr+3HeFI4eSUwYCvFYh+ffb2dN3rQZatQoX8VD7RkiI5TwZiqJw9vSzU3D9U/qXECLr0t3b0aVLF3UEVVpt27ZV8xAKSz4+PhY/iJOSkpg0adILH4n5LHv37lVHDr/11ltWg0Lt2rXD1jY5/9OyZcte6Gt0cHBgxowZfP/99zka2Nbr9VSqVIlZs2YRHBxMr169rAZnsiL1lusXrXXr1lSpYj3/pE6no1u3bhgMBhRFYf/+/domL0xO7V9mpU1x4OjoSKdOnTI8ZwaDgQ4dOkBKYDSjQF6VKlXw9PTUFkPKSNC8eZOvXt+9e5fIyEhtk2favn07iqJQt27dDEdW+vj4EBERwaJFi9i4cWOGo5DNZjNTpkyxGC1uzfnz59ULNC1atPjbR+2Gh4fTp08fli1bBkDDhg1ZsGCB1c8WrfLly3PixIl0we5Tp07RqFEj1q2znLSkQoUKjB49GlLyKTdp0oShQ4dy6dIl/P392blzJ61bt2bMmDGUKlVKDUZmVmp6g0mTJgFQsmRJ1qxZYzVtwp49e6hUqRK//vorZrPZIn3IiRMnGDhwIACrV6+mbNmymb5QlhXh4eEMHjwYf39/mjVrxogRIzJ8z2SkTJkynD17lqioKNatW6dOqqnVpk0bPvnkE0i5GPI8zgaE0nnDMTpvOMbZgCdpH4QQQoic4uBoS/uO1YmKjOP7WTuIiLBM63Tprwe0bvY1Pbr+wMMHIepIzmtXfLl750lQVlEU9u25zMZ16QdaNG32Gm7ujqxZeZxLfz2ZdyY6Ko453+wkKtL6AIKoyDh+X3qYuNgno3cv/fWANSuPU7iIB3VTJreq16AMxYrnYfvWc/y574pFEO/0yTusWXkcZxd72rZPHrRTuEhu3mhegYcPQli04IDFSMyQx1HMm7sHkymJth2q4abJIfwilfTKT2JiErt3XeThA8vv
t39duE9wcCQlS+XDxdXyO2NQ4JPv5HnyuNCydWV87gWz4Kd9FvuiKAp7d1+iSd3JDP1gcbpzmxlZ7R9/l2o1itOyTRUURWHh/P1qoLtSlSKUf60QG9edYt+eyxZ9Iz7exPQpm2lcZxKLFiTPiVG4SG4aNSmPz71g1q8+aTEK+OD+q1b7t9Y/pX8JIbIuXXA3d+7cfP755xY/Ih0dHRk3blymftD/18TGxvLRRx9x8+ZNdDqdeuv7oUOH+Pzzz7M9qiuzwsPD1QBMoUKFMpysqGrVqjRt2hRSRoseO3ZM2ySdFi1aMGDAAKvLyJEjWbFiBVeuXCEyMpKRI0dmq5+kzY2buly8eBFvb281YFyiRAm+++47PvnkE1xdn/+K4WuvvaY+zspt1FlRs2ZNDIaMk/EXL15cTSdx/fr1F9pfXsb+ZVZYWJiaC9rd3R2DwYC/v3+Gi52dndqPLly4oNlasjx58qgpNl600NBQjh07hk6nUwPNGenbty89evRIF3jTjs4+deoUR44csWijtXXrVhISkn8E9OnTJ1vvpewKCAigffv2rF27FlICu6tXr043Cj0j5cuXp1atWhkGu4cNG2YxiZdOp2P48OF8+eWXkNJn586dS6VKlShYsCCtW7dm9+7dNGvWjNWrVz8zLURa0dHRvP/++8yePRtSArtbt25VR3en9eDBAz744APCwsJwd3dn9+7dLF26VA3c16pVi/nz53Po0CHc3d2JiYlhwIAB3Lp1S7upTImNjVXPeSpFUZg/fz579+7F3d2dadOm4eaWfJtcTtDpdLz33nsZBn8zo8/2U6y/8ZD1Nx7SZ/uzfzwIIYQQL8IbLSrQ6e3XOXvmLh1bzWL0yJXMmLaFAb1/5oP+C4mKjKNdx+oUKpyL/AXcaNS4HHFxJj7ov5DRI1bw9ZRNdO88l4lj1/Jmq8o4u9hbpFQoUTIfIz9vR0x0PB/0X4j34MVMnbiBbp3ncv7sPYwZTLpltDWw949LdOn0PVMnbsB78GI+6L+QuFgTgz9qQb78yX/X8+Zz5eMRbXB0smPi2LX06zmfGdO24D14McM/WkpMdDxDh7fCq0zy9y8bGz19BzShStVibN18lnc6fMfUiRuYMGYNXTt9z4XzPrzZujKd3n6SUuxpFi88SOf232a4nD6ZnMZKq0TJvPTs0xA/31D+1+UHRo9Yob7uiWPX4uRsR9/+TdRJyfLkdcHe3si2zWeZMnEDhw9eR6fT0b1HPWrX9WLr5rN0ajObCWPW8PWUTfTo+iMTx65Nvquva21cNUHizMpK/8iM35cepmGtiUyduCFT5ZlhMNjQq28jcudJnpxu3aoTKIqCm5sjn4xqi4urAxPHrqVH1x/5esomJoxZQ6c2s9m6+SzFiueldbvkwL+NjZ7e/RpRuEhulvx6UG3vPXgx4z5fjbNz8u8mp5T/W/Oi+5cQ4uVJF9wlZbRh27Zt1X8PGDCAWrVqWbQRyT/A586dy9atWyElMLlnzx711uKff/6ZTZs2adbKGadOneLkyZOQknKhZMmS6KykNXB0dGTXrl2Q8vqXLFnyzICit7c3P//8s9VlxowZdOvWjfLly2d4a3FWpM2Nm7pUrFiROXPm8Oeff+Lu7s7t27dp1qwZ33zzTaZuUclI2lQFaVMYvEhpA6zWpL0l+3nTB2QkV65camDZ19f3qbeY57S0qQz8/PyoWrUqBQsWzHBp1aqVOqnW3/Ha79y5w5UrVyhatGi6AKDRaFRzgOvSjL62RqfTWYxSvnTpkraJKjg4mI0bN0LKBZratWtrm7w0V69epX79+hw6dAiAN998k40bN2Y6sGuNNtjt6+ubLreu0Whk3LhxXL58mXfffVe9WKbX66lfvz5bt25l165dODg4qP2pbNmn50gLCAigdevW6sWvChUqsGvXrnTnNdX27dvVQO3nn39O8+bNtU0AqFevHnPnzgXA399f3T4p+5FZoaGhRERY3vZ2+PBhxo8fD8BXX331
Uv7+5s2bN1vn99rjJ/uQ9rEQQgiRkwwGG4Z/1oZJ07pQpGhuDh+4xpaNZ7h18xF165dh/q/9efe9uuh0Omxs9Hw0vCUferfA0cmOwwevs2PreYoWy838X/vz0fCWFC7swX2fYIs0AY3feI35iwZQpWoxzp+9x87tFyhUKBdfz+5OhYrWU59VqFiYHxe8T7Fiedi5/QJ/nfehbv3SLPp9MI2bWn4Heb1WSZau+JB2Havz4EEIWzae4dJfD6hbvwxLVnxI63ZVLQYReOR2ZvbcHnw8ojX29kZ2br/Avj2XKVgoF5OmdWHshE7YOyTfqfkscbEJBD4Kz3DJKEerTqej2//qMmlaF4oVz8vhg9fZsvEM16760bpdVX5ZMpDyFQqp7Ut55affoKYoCuzafoF9uy+hKApOzvZMnfEuH49ojbOzPfv2XGbb5nMEBkbQul1Vlqz4IN3xyoqs9I+/U+EiHnT7X10A1q05ya0bAZAyad2iZYNp17E6gYERbNt8jn17LuPsbM8H3i34v4X9KOD5ZNLsAp7uzFvYj7e71OJRQDjbNp/j4YPHfDyiDZ+NSU7vlTefZUoSrRfZv4QQL49OySA6dubMGRo1aoSbmxsHDx586kRrL8P06dP5/PPPIWWCGW3uwaioKNq1a8eBAwdo3LgxW7dufeoM8y/CoUOHaN68OQkJCXh6eqrHad68eWoAo2TJkuzdu5fixYtrV4c0xzkmJoZevXqxZMkSbZNnUhQFb29vfvzxR23VM7m7u3Po0CEqVqxoUd67d2+WLl0KGRzvrErdnqOjIwcPHqRGjRpqXVaOwbp16+jatas6kdrq1avp3LmzthlkYrtp6xs0aMD27dtxcXn6H7u07ty5Q6dOnahevTpdu3alRYsWGI1Gi+0+69j5+/tTp04d7t+/b9Fv05Zbe+1ppR7bokWLcvz4cTVVga+vL/Xr18fHx4dChQpx/PjxLOXeDQ0NpVOnTuTNm5e3336bjh074uTk9NTXltExT7tOVqU9LhltXyvt54H2uPCM/gjwww8/4O3tTefOnVm5cqVF8DbttjNaP63MvuY9e/bw5ptvoigKffr04ddff83WF80tW7aoo46f1Q9TKYrCli1b6N69uzp52qBBg/j2229f2CjiHTt20KZNGwD69etnka87s9Ieq3nz5jFo0CBtE0i56NW5c2cePEi+jfLNN99kxYoVTx2hmto3HBwcOHDggNU0RakePnxInTp18PX1pVmzZmzatAknJyeLv1fbt29/akqjw4cP07RpUxITE5k6dSpjxoyx+PzNquf9+5e2X2fwteCpRuy7wOxTyXniPq1ZhllvWE9HI4R4tTwKPs/Nu+kHKjg65CUm1vqkPUL8lz0OjmTQ+7/gWTAXM755DwdHCYKJf4bflx5m3g97GD/5bd5sVVlbLUSOq1i2J+6uT+ay+a8JCgrC3t4+SzGnrLA6chegevXqDBgwgGHDhv3tgd1/Ih8fH/r160dCQgJ6vZ65c+eqx6lfv3506dIFUgKAAwcOVEeZ5YR79+6pI/6KFi3KokWLWLNmzVOX1NQMYWFh6m3Xr4K3334bb29vAMxmM/379+fcuXPaZpni5eVFpUqVADh79iw3bjxJYJ8ZJ0+e5OLFiyxZsoQff/yRpKQneY1S+fn5aYssJCUlqesVKlRIHamYWYmJiRmOas2fPz916yZfAfb19eXs2bPaJk917do1jh8/zrp165g6dao6kvZ56PV6NUDauHFjIiMjURQlU8uff/6Z5UBVdsTFxbFjxw5ISUmiHZXr7Oz83J+JTxvRuX//fjWo1rJly2wFdp+Hoij8+OOPvPXWW2pgd/r06fz4448vLLBLSjqN1O09b7qQM2fOoCgKBoOBChUqaKshZTK+Jk2aqIHdQYMGsXHjxqcGdtPy8PDIMKdzKhcXF4oUKQIp78XU81e58pMvzM+6oOHj46PePVGuXDltdbZFRUXx6NGjZx7ryMhIizQZWTXrjSqc
6dOcM32aS2BXCCGEEOIlexwcybtvfc+nQ5cRHp78XT7VwwchbN10Djd3R0qnpPgQQvy7ZBjc1el0fPnllwwdOlRb9Z9nMpkYOXKkeiv/wIED6dixo1pvNBqZOXMmpUsnJ8rfvXs3P/3003ONhsqMo0eP4uvrC0D79u3p3bs377zzzlOXzz77TA0erVq16m+fcCuzdDod48ePV0fThYWF8cUXXzxX4NHNzY0ePXpAysRdv/32W6bPUVRUFPPmzVP/3bNnT6v5Xy9fvqwtsnDt2jU1AFy5cmWrAb3o6OgMU2fExsZmeO4MBgO9e/dWt7lgwYJMH6fExEQWLlyo5gLt0qVLlnKcarm5uamBq+vXrxMYaH1m4X8CPz8/zp49i4eHB40aNdJWA6gXR2JiYrh27Zq22oKvr68aLM0oTUdcXBznz58HoECBAi/lNvy0FEXhm2++wdvbG7PZjKOjI5s2bWLkyJHPTLcSFxfHkCFDKFy4MOXKlVM/izJy48YNtR+mTamwYsUKChcujKurK4cPH06zhqWoqCg1+F6mTBmrAdH169erQWq9Xs93332X5SB1SEgI/v7+2mILkZGRavDYYDCo77XixYurQeTjx49n+LmiKIqa99zFxYVSpUoB8P3336fLPa5dzp8/T8GCyTNs16lTh5s3b+Ln58e6devUi0T9+/fHxcUFT09P9u7dm+aZ07t9+/YzL0Y9S/X8uaieP3M564QQQgghxIvjnsuJWrVLcfL4Lbq9PYcJY9YwY9oWhn2whP91mYufbwh9+zeheInkSaeFEP8uGQZ3SfmxmZUfw/8FiqLw7bffsmbNGkiZMGvKlCnpRvcVK1aMr776Cr0++RCPHz/+qQGL5xUbG8tvv/0GKcGFbt26WQ0QatWtW1cNIF27do39+/drm/xj5c6dm9mzZ2Nrm3yb09atWzO81f1Zunbtqo78mzt3LuvXr9c2SUdRFH766ScOHDgAKX2gRYsW2maQcgt6avBHKzExkZUrV6IoCra2thaBxLQjXR89epRhUNbHx+epeVwbN26s5s/eunUrc+fOzTDQlNamTZtYtGgRAJ6enmoQ/HnZ29urt6YHBASwfft2bRMLy5Ytw2AwULhwYWbNmqWtzlGXLl0iICCASpUqqaMyterVq0ehQsl5xBYsWJDhyHyTyaTmYrW1tVVHUmsFBwerE8d5eXmRN+/L/dK1fv16PvvsM0hJ1bJ79246dOiQqc8Se3t7bGxs8PX15fr160/9LImNjWXlypWQcqEmbRC7SJEiPHr0iMjISPbs2ZNmLUvHjh1TcwFbu+hw8uRJ+vXrh8lkQq/Xs3z5coYOHfrMIHWqevXqQcpr3bx581PfL3/++acazK5du7YaVC1RogR16tQBYNeuXdy+fdtivVQPHz5Uc6DXq1ePMmXKQMo50OYe1y758uVTPyPs7OwoUKAAnp6e5M6dWz1vqRchFEVh7dq1GV4kMplMLFy4MMN6IYQQQgjxz5aaU3rsxLfIm9eVfXsus2XjGf66cJ/adUvzf7/04+0uNTP1/V4I8ep5anD3nyR1krK/W9qJbtzd3fnhhx/InTu3thkAHTt2ZODAgQAkJCQwcOBAHj58qG2WLefPn1eDKTVq1FDTDDyLm5sbb731lvrvxYsXZxhA/Cdq0KCBmtcYYPLkyelmq3d0dLQ6mjatfPnyMWnSJPR6PWazme7du/Prr79aTbFAShqFOXPmMHr0aEgJwk6dOjXDPnDr1i0mTZqU7tgqisK6devUAOqbb75J1arJM52iGel69OjRdBNPkTKycOjQoYSFhWmrVA4ODowfP159/4waNYovv/wy3etJpSgKmzdvplevXpjNZgC++OKL505DkFbbtm3VgOj48eMzDODdunWLyZMnk5SURFBQEPXr19c2yTY3t+QZirUURWHDhuRZbuvXr59hOojixYvzzjvvAHDgwAFmzZqV7rZ3RVH4/fffWbduHaR8Hrz+uvWZZe/du6eO
wC5btmyO5eGx5tatW+qIXXd3d3bt2qUGODOrS5cu6sUWa+9FUt473333nToBZdu2bS3ObaVKldTcxT///LPVEdFXr15l8ODBmM1mqxcdHj9+zEcffURYWJiak/vdd9/N0pfYNm3aqP39m2++yfBCxPnz5xkzZgyk/C1I7Q+kBLz79euHTqfD39+fSZMmpUufEhsby6RJk9Rj1adPnxd+MbVRo0bqvixatIh169alC1anXrBcvHgxpKQiEUIIIcTzyZ3HhbWbhzN3Xh/JtyteOjs7I63aVGHpyg85dHIih05OZP/R8cz49j1eq1g4S9+JhRCvllcmuJt6+ykpo/pu3rxJYGBghkG4nJA2zy7AuHHjnnr7tMFgYMqUKWoKgWvXrjF58uR0QaBUBw8eZODAgZlaFi5cCCmjMVNfz1tvvZVh0Mqatm3bqkG//fv3q7eFvwp0Oh0jRoxQA6D+/v6MGTPG4ti6u7urwbnDhw/z559/4u/vT1xcnNqGlDy+33//PXq9HpPJRL9+/ShZsiTTpk3j0qVL+Pv7c/PmTX766SfKli3Lxx9/rE7otnDhQpo3b26xPa2FCxdSr149Nm3ahL+/P0ePHqV3795069ZNDahNnjzZIrBjb29Pz549ISW3cPfu3ZkwYQI3b97Ex8eH+fPnU61aNfbv30+uXE+/DbpWrVosW7YMR0dHACZMmECBAgUYNWoUJ0+exN/fHx8fH5YtW0bt2rXp2LGjmkZgwoQJ9OvXT7PF51OiRAn1OIeFhdGiRQt69erF0aNH8ff359KlS4wdO5bXXntNTXnywQcfqCMgX6TUIHNMTAwrV67Ex8eHx48fExwczMmTJ9HpdOqoR2t0Oh2TJk2iWbNmAEyaNIlatWqxZs0aHj58yNGjR3nvvffo27cvZrOZPHnyMGHChAxz7gYGBqqjJjNK3ZATFEXhu+++U9MPuLi4sGDBgnSfN9aW1NQIpFxs+eijjwC4efMmNWrUYNq0ady8eZOHDx+yZcsW6tatqwZDS5cuzQ8//GDR593c3Bg5ciR6vR5/f39q167Njz/+iI+PDzdv3mTChAlUqVKFO3fupMtznmrFihWcOnUKAFdXV7Zs2ZLudVtbUj9PSRlB/PXXX6PX64mJiaFdu3b06tWLPXv2qP106NCh1KpVSx2VP3r0aIuLM6SkyOnfvz+k/M1q2rQpO3fuxN/fn507d9K4cWP1eQcMGJDhxJDZkXZfzGYz3bp1o3fv3up7bufOnbRs2ZJRo0ZBSj7kadOmaTeTKacDQum47ggd1x3hdECotloIIYQQQgghRA7RKdphPP9Qd+/epWHDhhb5HF1cXDh06BBVqlSxmOn7eWcLfxqTycT//vc/NR1Dly5d+P333zMM1qR14sQJ3njjDTVg9ttvv6kjzs6cOUOjRo3Uuszq1asXM2bMoEmTJly7dg13d3cOHTpExYoVtU0zlJiYSM+ePdXbpIcMGcLcuXPR6XQWs7Vv3ryZ9u3ba9bOmtTtOTo6cvDgQXWEHppj0KtXryylWFi3bh1du3bFbDaj0+lYsWIF7777LljZv1SzZ8/mk08+sSgjZQRmv379MryFOq1SpUqxcOFCGjdurK2y2J8GDRoQFxfH6dOntc0gZTtr165NFxgipc8NHTrUIrev1qeffoq7uzvjx4+naNGiHD9+PMNJoG7cuEGPHj3U4NfTuLq68vPPP9O1a9d0V3j9/f2pU6cO9+/fT3e+nnUuFUVh27Zt/O9//yMiIsKiTuuzzz5jypQpFu+xZ20/VdrPA2vH5dixYzRp0kS9MAJQsmRJvv32Wzp37kyZMmU4cOBAulv+tQICAujRo8dT85kWKVKEzZs3Wz3HqaZPn87nn38OL+j9BrBlyxY6dOgAT9mmtc/VzPr666/VoCAp/XX69OnqnQ0ZqVu3LsuXL6d48eLaKpSUSd2GDRumjhzXcnR0ZMWKFbRv396ib4aHh9OyZUtOnDhh0T4ztH0ps/3UaDTy
3Xff8cEHH6R7n5Dymvr37//USStbtWrFypUrs3RhDs378Gl/8xRF4bfffqN///4ZXlgkZTK39evXU758eW1VplT69Q8uBYUDUDGvGxfff1PbRAjxCnoUfJ6bdzdpi3F0yEtMbJC2WAghhBDCqople+LuWlJb/J8RFBSEvb19jt2l+8qM3C1RogR//PEHDRs2VMsiIyOzPQFMZqTetpoa2C1dujQzZ87MVGCXlJGTEyZMUP/t7e3NuXPnLNo8j/3796u3LterVy/dKLZn0U64tXHjRu7du6dt9o/WoUMHevXqBSnnacyYMeo+GAwG5s+fz7BhwyzO1ZUrV9THaTVu3JirV6+ya9cuevfuTdGiRS3qc+XKRefOndm8eTNXr161GtjVKlmyJAcPHmTOnDkWo88rVKjAkiVLuHjxYoZBP6PRqOb2bdmypboPRqORtm3bcvz48Sz1wzJlynDs2DGOHz/O0KFD8fLyUnNCAzg5OdGsWTN+++03/Pz8snw7e2bodDratWvHgwcPmDVrFpUqVbJ4DQULFmTQoEFcu3aN6dOnZ3rfsqpOnTrs2LFDzbdMyoftvn37SExMpFatWhmm2kirQIEC7Nq1i71791qcI1LO8dy5c7l69WqG5zhV6h0IBoPhmSOxX6SQkBBCQ1/MKEuj0ci4ceO4du0agwYNsujvRqORli1bsnXrVg4dOmQ1sEtK//joo4+4cuUKffr0UXPYkpLH/IsvvuD+/ftW8wHHxMQ8cwK0zHpWPy1VqhRjx47Fx8eHDz/8MN1rSeXm5sbKlSvZunUr9evXV7eh1+upX78+W7duZevWrVkO7GaFTqejV69e+Pv7M336dHXSNlJeR61atfj99985e/bscwd2AW6ERFp9LIQQQgghhBAiZ70yI3fFqy115K6HhweHDx/OVhDhny6zo0uFEOLfYtSfF5lxIvli42e1yzG9Sebyvwsh/tlk5K4QQgghXgQZuSsjd8W/iLOz8z9mcjwhhBAvxvQmlbjw/ptceP9NCewKIYQQQgghxEskwV2R4+Li4ggODoaU287T3t4shBDi36FyXjcq5825FBNCCCGEEEIIIdKTKJvIMVeuXOHhw4esWrWKPXv2QMqkPTmZX1IIIYQQQgghhBBCiP+Kf31wd/r06eh0umwvvXv31m5aPEVcXBwjR46kSJEi9OnTh4SEBPR6PYMGDcLe3l7bXAghhBBCCCGEEEIIkUX/+uCu+HtERETg5+en/rts2bKsXr2a9u3bW7QTQgjx6jvhF0K7tYdpt/YwJ/xCtNVCCCGEEEIIIXKITlEURVsohBBCCJFZFX7ZxZXHEQC8ltuVy/1bapsIIV5Bj4LPc/PuJm0xjg55iYkN0hYLIYQQQlhVsWxP3F1Laov/M4KCgrC3t8fFxUVb9ULIyF0hhBBCZMutsCirj4UQQgghhBBC5CwJ7gohhBAiWz6pWcbqYyHEv4NO0fxkMJst/y1UckukEEIIIV42ScsghBBCiGy7HJyclqFCHldtlRDiFeV/dTNKTDh2LjVxPu2J2ZgEQLxyl8clL3DFFQwO8p5P5aJ3wFXvyP3EIGxewhiahIQEDAYDen3OP5cQQgjxPBJJopZDWW6FnqNY3poUdq+obfKfkNNpGSS4K4QQQgghhEgnxvciccG38Cj5Fnz3pDyhvB/+bkvZWVICu2k1zNMYDzsPNvlu0FYJIYQQ/0m5bHPRtUh3DtxayGsF3iCvcwltk/+EnA7uymVeIYQQQgghhBBCCCGEeAVJcFcIIYQQQgghhBBCCCFeQRLcFUIIIUS2HPUNpvWaQ7Rec4ijvsHaaiGEEEIIIYQQOUSCu0IIIYTIlv47zrDzTgA77wTQf8cZbbUQQgghhBBCiBwiwV0hhBBCZMudsCirj4UQQgghhBBC5CydoiiKtlAIIYQQIrPGHrzEtGNXARhTtzxTG1XUNhFCvIJifC8SF3wLj5JvwXdPyhPK++HvtpSdJV3TNs+08MBwpr8z
izyFczPklw+wc7TTNkln8cilnN15jhErhlO0YlEATHEm9vy6l0pNK1G4fCHtKi9dwzyN8bDzYJPvBm0VAPcv3WdW929JiE3QVmXow58HU6xi0Swfr5y2a/4fbJi5iQ9/HkzlNyppqy38te8iPw2cpy22ytbB1uIcv6oUReHS/suEBYbRsFsDbfXfIj4mnh/7/x83Tt7UVllV563a9JnZS1v8yokOi2bHT7t4c2ALXPPkzCz1AEmJSfw6fDEA73/bBxuDjbZJjrP2Ofk8ffHGyZusnLAKv5v+ADTp2ZhuE7pqm4lMSH3f3bvok+3Ptnt/+XDpz8u0G9pGLcvKZ/GLEuQTxOz3vqPNkFY0eq+htjqdXLa56FqkOwduLeS1Am+Q17mEtsl/QlBQEPb29ri45MznkIzcFUIIIUS2TG1Ukav9W3K1f0sJ7AohXpqVk1az4/92YU5K0lb9I9kYDeQtmodcnrnUxS2fm1rv4uFsUZfLMxdGO6PFNl5lOr0O9/zu6fYx7ZK3aB5sjAbtqq+cywev8OPA/yMmPEZb9Y/g6OaY7thrFyd3J+1qr5z42AQWDF3I2Z3nUMxmbfULdWT1US4fvELLQS3+lsBuRrLaF4MfPubX4Yvxu+lPmVqlafBufSo2qaBtJl6yh1d9+a7nHIIf/P1zW+Qtlpc2Q1qxcfZmHlx5oK0WfxMJ7gohhBAi28rldqVc7ucbxSeEEM/SZ2Yv5lz81mLUU1LiqxHUTVWobEHGbxvLV4emqMuYjaPwKOiBrYMt3r8Osaj76tAUytcvp93MK6v0615M2v1Fun1Mu4zfNpZCZQtqV33lmJNyNpCYXX1m9kp37LVLl7Gdtau9ehQFc2LOn4sgnyC2/bCDel3qUqRCEW31S2PtczKrfTEsIIyI4AiqtarKx78NpcfU96jYWIK7fzdzUhJmKxcoWg56k3m3fnxpo3ZT1WhTnXzF8rLz//545f4W/1tJcFcIIYQQQgghhBAiixRFYf9vBzDFJdCgaz10Op22ySslJiIGc5IZOwc79DYSLhLWObk7Ueet2pzddY4bxzOX6kXkLHm3CiGEEEIIIf4Wj+48Yu77PzLYawiDvYYw/Z1Z3Dx1C+20IItHLmVopeHcv3Sf+5fuM7TScI5vOEFCbALTOk1X6/7tQv1DWTbmd4aUG8pgryGMqjeGP387YDFyKj4mnm/e+y7DY7Jr/h8M9hrCrvl/qGWLRy5lTKPxBD8I5o8FexhZ+3MGew1hSLmhLBvzOxFBERbbsEZRFP74eTeDvYbwRfNJhPiGaJtkWlJiEktG/cZgryEsGfWbxf4pisLqKWsZ7DWEZWOXW9SFB4azdtp6Pqk+Qn39v3z8K6EBYWqbtExxJvYv/ZMxjcerfXBq+6+4fuy62gdT+9s3731HfEy8xfraY53679Qcwxtmbkp3rK095+TWUziz/Wy6UZa75v/B0ErDubjvEvM+/JnBXkP4pPoIjq45ZtEuJ929cI+hlYYztNJw7l64Z1Hnd9OfETVH8Un1EdxLU2dOMnNx/yWmvzNL3ccxjcdzasvpdPuY2v7M9rNMbj3lqe0Xj1zKYK8h/LXvosX6pOR2Huw1hMUjl6r/Hlb5E26cvEmIXwij6o1hTKPxhAeGq+s8uvOIJaN+U99PwyoNZ/WUtcRFxaXZ8tM99g3h7I5zlK1XlgKlCgAQFRrFpFZTGNN4PGGPLPveuV3nGew1hEmtphAVajkJ7aEVhxnsNYRDKw6rZf63/PlxwP8xrNJwi2Ozf+mfmOJMFuun/ZzMTF9MKzwwnDGNxqvtj284wWCvIRb9/mX13dCAMJaN+V3d56e9j1OPT9pzuGzM7+napu7f4pFLefzwMT8Nms8HZT5S9yE1F/XtM3fUfvtBmY9YMHRhum2lfl4G3gtkw8xNeFf8mMEpn8fWzktGTHEmjm84wYQ3J6vHU/v5Q8rzTes0nYTYBPW8pPbz1M9z7Xviec7VvQv3OLXltLrOB2U+4of+
P/HobqBF+1QVm1TAJbcLR9YeTbdN8fJJcFcIIYQQ2XL4YTAtVh2gxaoDHH749+cCE0K8Gu5fecB3vedy++wd6rxVmzK1SnP3/F2++d93HFz+JLih5eTuRIOu9chdODc6nY6qLarQoGu9f0WO0Ke5d9GHKe2/4uzOc1R9swqVm1UmIjiClZNWs+7rDekC4lkVFxXLks9+Y8PMjZSoWoI6b9XG0dWBw6uP8l2vOUQ+jtSuolIUhT2/7mP9jI3kK56Pj5d441HIQ9ss02wMNrTzbkOeInk4vuEElw9cVusu7rvI/qV/kq94Ptp82ErNb/rg6kO+7jyTPb/uxcndiXpd6lKqeklObz3Dl22mWAQfAWIjY/lx4P+xavIa4iJjqdGmOpWbVebhdV++7TmH3Qv2ZPmY6m30VGxakfINktNpFK1YlAbv1qdQueQJ/yKCIvi+7w+smrwGU2wCdd6qTY021Ql6EMyCoQtZ/sXKdLc4m+JM/DJsIbdP36bOW7UpWKYgnqWTg4ipwU5twC6j8udRvHIx2gxpRUJsApu/3aIG+uJj4lkzdS1RoVG0GtySYpWLQUpgfsOMjfw44P+499c9KjerTJ23amOKTWDh8EXp9jEpMYnlX6xkwdCFPLobSJXmlanRpjqRjyOtts+sXJ65qNu5Ds65nDEYDdTqUJOa7V/HaG8LwMX9l5ja8WuOrTtOwdKeNHi3PvmK52Pf4v1M6zQ90xcnrh66StijMF6rX14d6erk7kTxKsUI9QvF/1aARfsbJ5KDiI99H1s8hznJzJUjV3F0daBUjZIAnNt5nsltpnLpwGXK1itHg3frU6V5ZcIfhbNq8hrWTFuXYVDtWX1Ry2hvS832r6vtcxfOnZxvt2lF9Db6F953M3Lvwj2+bDOFw6uP4u6Zi3pd6lKwtGe697GiKBxbd5zJbaZycf8lSlUvqZ7Dw6uPWn3PA9w6fZuv3prB3XN3qN2xFl6vl8Lvpj/zBs9n/fQNzH7vWxLjTdTrUpc8hXNzZvtZfh6ygNjIWIvtmOISWDlpNbvm/0GJysWp81Zt4mPiWTV5DfM+/Dlde63YyFgWjVjC4pFLCfELoUab6tRoU52Au4/Sff6UruVF1RZV0Ol06nkpXctLu0nV85wrRVHYMHMTv36ymDyFcqv7f+nPy8zsOludXC8tj4IelKpRkisHrxBw27Kfi5dPgrtCCCGEyJb+O06z514ge+4F0n/HaW21EEJYFRcVR5HyhZn652T6zOzFJ8s/ZuTqT3FwtmfLd1ut/pgkJejQdXwXStf0wmhvpM2QVnQd34XchXNrm/6rJMQmULVFFb46Mo0Bc/rx4fxBfLZmBLYOtpzdcY7HmQxGZSQmIpbwoAgm757Ah/MH0WdmLybvnYhXTS/8bvpz9/xd7SqQEhQ4s/0s66dveCGB3VS5C+em04gOoMDGWZuJfBxJaEAY62dsBOCd0W+pzxMfE8+aL9cS6h9Ku6FtmPjHF/T6qgefLP+YQT8MIDYqjqWfL7MIUO9dtI9rR69Tu2Mti2M6as0IHF0d2L1wL4EZjFjLiNHOyJsDmtO0VxMAarSupuYsNSeZ2Tp3O7dO3aJ+l3pMOziFPjN7MWBOP6bun4xXTS8OrzrCkdVHLbapKAoueVz5fP1n9JnZixErP6FE1Zc327xOp6Px/xrhVdOLq4evcWLjSQAOrjjM1cPXKN+gHI17NFLTEfy19yJ7ft1HniJ5mLBjnEVfKlevbLp9PLfzPIdXHcGrphdfHZrCB/MGMWBOPybt+oJ8xfNxZM1Rbp+5o7bPrCLlC9NtQlcKlvbENa8rnT9/i7dGdsTR1YFQ/1BWT1mLKd5E7+k9GbP5c3pMfY8xmz/n7c86EXgvkBWTVj1zBGZqQNZob7TItavT6ShfrxyKonDvLx+1PC46jgdXHqDT6UiITSDg9iO1LiI4Ap+L9ylcrjC5C+cmIjiSTd9uxsHZnlFrRvDh/EH0mPoeH8wbxOQ9E8hTJA/n
d18gxM/6+/5pfdEaR1cH3hrZUW1fuqYXPaa+x5sDmmNjsHkpfTc2MpZVk9cQGxVH7+k9mbhrPL2+6qGel5iIWHbO+wNTvImHVx+yZupaHJztGbn6Uz5Z/rF6DntP72n1PQ8Q/CCYKs0rq/vw6YrhNOnRmJiIWHb/spe+s3ozZvPnyc+76XO8anpx7y8fi/MIEBkSxb0L9xi2xJtPln9Mn5m9mPrnZMrWKcPlg1c4teXp34cPLj/E2Z3n8KrpxdT9kxkwp59Fv98waxPXjlwHoH6XerQZ0gqjvVE9L/W71NNuElL65POcK1OciQdXHjBi1Sd8svxjen3Vg4l/fEHdznWICo3iopXR8nobPUUrFCUmIva53qPixZLgrhBCCCGyxSfiyQzMaR8LIcTTOLo68O4XXSxG3JaqXpI3+jTN8Mfkf5mjqwNthrTCziF55CFA4XKFKFO7NDERMcSERVu0fx71u9Qlb7G86r8dXByo1CQ5GKQdgZjq/K4LLBy+iLxF8+K98MMMA7s3Tt5kWOVP1FuErS3akabVWlalztu18bvpz+6Fe9n+4w4Cbj+iaa8mVEozgdCt07e5eeoWJaoUp1nfN9TRvABVW1ah8XuN8Lvpz/VjNyDltvkz28/h4uFMW+/WFse0eJXiNOnZGBSFR3efBN+yK/BeIGe2nyVvsby0G9oGo71RrXPN68rbn3XCYDRwZM1RYsIt/5ZWalrR6sWLPjN7Me/Wj7Qc9GamylP9NHBeumOfdtGmL3BwcaDz529h62DLth92cHrbGXb8uAPnXM50GfsOdo52kBIgOrTyMIqi0OnTDmqagtRtdB33Do6uDhxbd5zYyFhM8SaObzyB3kZP+2Ftcc37ZGJWj0IetPNug4OLAz5W0otkx4U9fxHkE0SNNtWp3amWGpjW6XQ07dWE1xqW5+qhazy4+lC7qoWYiBgCbj3CxcOFXAXcLeqKViyCk7sTt8/cxhSfHCQOCwjD/1YAVd+sgmteV64evaaOzvS97keoXyilapTEztEO/1v+xETEUvXNquqo6FTOHs54eOYiLiouXV/JCTnRd61JDaJWaVaZWh1rWpyXBu/Wp1SNkoQFhhMVGs3RtceJiYil5aA3KVU9eaRzatvanWpRo011/G76c/XwtTTPALYOtjT+X0N1H3Q6Ha81Kg8p56xCo9fU53VwcaD066VQFEU9h2m90aepxaSXTu5OdPy0AwajgdNbz6RL45IqIjiSY+uPYzAaePuzTun6/btfdAGF50p3kJ1zVaVFFUpWexJ8tzHYUK1lVXjK53/qBJja4Ld4+SS4K4QQQohsGVGrjNXHQgjxNAXLFLQaCCxVoxQ6nY57F32y/MP238zRzQlnDxeLMqOdEedcziQmJJLwjFGGmWHtlm3P0p7aItW5XedZNHIJKNB13DsWgWEtnV6He353cnnmynBxcHGwWCdteobdC/ZwaMVhilcuRtuPWltMXHUrJU9z5eaV021Dp9NRtm7y36arR5IDPSG+ITz2fUyhcoVwy28ZlAPoMLw9M09Op3Kzytqq5xb88DHRYdGUqFIcd00gEKBgGU+KVynGozuBhPiFWtY95Rw8D0c3x3THPu3iUTBXusm0UtMzhAeGs/DjRcRExNJlXGeL1xYeFM7Da76453en1OulLNYnJXBVsExB/G8F8PhhCDHhMfjd8MejoAeeXun3sVbHmnxzZiYt+jXTVj03c5KZ6yeSg/w1WldPt59GeyOla5Um0ZTInXNPH40YFRJFRHAEHgVzYeeUHOBOlcszF4XKFOT+5QdqoDzwXiAx4THUaF2dAiXy43fDj9iI5Nv3L/15GZ1eR4WUkbVl65RhxrGv6PV1DxJiEwi8F8jZnedYM3Ud0zvPVHPEvgwvq+/ev3QfRVHwqlnK4gINKX125KpP+XzdSBzdHPG/5Y/R3kjZumUt2pEyorRG6+oAXD+efK5TOedyxi2fm0VZKk8vTxzdHC3K7J3tLf6dymhvpGKT
itpiCpTMT6FyBQm4+yjdqOFUgfcCCbwXRPEqxShYJv3xKVimIG753Lh3wYeo0KxdtMvOuSpUtqDF5yqAez43bNNc/NJKrX/s+5j42ARttXiJJLgrhBBCiGz5smFFbg5szc2BrfmyYfovukIIYY1HQQ91xF9aTm6OGO2NRIdEWx0t9V/l4ZkLO0frP7LNSWZisnnnhK2DLa65LYPHz3Js3XFMcSYUReHA8oNPvY299OteTNr9BV8dmpLh0ui9htrVktMzfNoBUoI2nUZ0TJdfOXXCo6uHrrJs7PJ0y4mNJ9HpdAT5BBEfE09YYDgJsQm453e3GLWbk0JTAimepT3TBVAAbO1tccnjQkJcAnHRlhN6ueW3Hox6Xn1m9kp37NMuI1Z+goumL+h0Ohp1b0jJaiVQFIVKb1RUA2ipYsJjiIuKIyE2ns3fbEl3HtZOW09oQBhx0XGE+IcQHhhOVGjUU/v2i2aKNxEdEo1Op+PoumPpXuOyscu5kRL8fXjVV7u6hYTYBEzxJvR6fbpzaudoR4XGrxERFEHAneQR4FePXMcljwvFqxSjULlCaoAtJjyGO+fukL9EfgqUyq9u49HdQKa/M4thlT/hi+aT+PmjX9i7aB9RYdFPDbi9aC+r76aODs1XPJ+2yoIpLoHwwAic3JwyDNS65HZBb6MnJjLG4iJhnsK5MwzYZkVGz21jtMHOwe6po6pjIpJf02PfENZOXZeu/23+dguJCYlEBEcQFmAZgH2W7JyrZx33pzEnmiGLOcrFiyXBXSGEEEJkm1cuZ7xyOWuLhRBC/Is553Jm6OKP8KrpxcV9lzi765y2SbYlJSZxLWX2eHOSmdPbzqSbDCjVjZM3ObzqSLrl/O4LWZ4c7e9iY7TBYGvQFv8jBN4LVCdOun36Ng8zSFsQExHLsfXH052Hw6uO8PjhY23zv4WiKFzcdynd6zu86ki6W/mfV6kapdDb6Ll9+jbxMfH4XvOlYGlPXPK4UqZ2aeKi4/C74Ufwg2D8bwVQvEox9cKF301/Znadzd3zd6neqhoD5vZj2oEvmXv5O77cO5HilSxTNfwT/BP7rtHWiE6fPsiZXXobPfpsbjfUP5TDq4+m63/H1h0nKjRK2/yF+ieeK5E9EtwVQgghhBBCvHRhAWEkWLmNMzo8BlOcCfcC7i91dJrIuv9N6U75+uXo9El7DEYD66dvJMgnSNssW/7ac5Eja45SuHwhCpcvxJE1R/lrj2U+5tR8p4N+HMC8Wz9muHyy/GPsHO1wzeOK0d5I2KOwl3Yrca6CuQDwv+lvNdAcFxVHiG8INjY2GIz/vKBLbGQsG2dvJjYyjmqtqhIbGZfy7+S0AqTcOm/vbI/X66X47sLsdMc/7VL5jUq45HHF0dWREP9Q4mNeznkw2hlx8nDCaG/k8/WfpXtdaZc+M3tpV7dg62CL0c6I2Wy2ek4LlMpP/hL5uX3mDv43/fG94UeJKsWxc7DF06sATu5O3L1wD5+L90mITaDam1XU0ZanNp8iKjSKNwe2YMDcftRoXR2PQh4Y7Z7kUH1ZXlbfzV8yedRy4L2nT2RotLfFLZ8r0eHRFrmh0wrxDcGcZMZoZ7Q6gjW7okKjrD63Kd5EZEgU9s726VI8pHJ0dURvo6dWh5r8380f0vW71GXOxW8pWrGodvWnelnnSktv0EMOHGeReRLcFUIIIYQQQrx0/rf80/04VhSFq0eSJxkqXdMrR36U/9vZOdrhUcgDU5yJaM1tweYkM/cuvriJb2yMyXkxS71eikb/a0h4YDjbf9qZ4cjarArxDWHDrE3Y2NjQ7YuudPuiKzY2NmyYtYkQ3xC1XYmqxQG4fuyG1YCGlkdBD3IX8sD3mi/hj5JTOqR1ctMphlUazu6Fe3HL54ZzLmciQ6LSpQmJfBxJQCYnXctTOLcazAtLSSORVsCdR/he8yOXZy7cC6S/3fvvpCgKh1Yc5trR61RvVY0+M3tTvVU1rh29zqEVyROokXIrfGou
2bTnJyPO7k4ULONJiF8I/rf8tdXcOn2bT2qMZMWEVZiTzHh6JU/QFv4ofVDt7rm72iKr9DZ6ilcqhinOxL0L97TVWeLs4YxrHldC/EKJj04/eZaTuxPFqxTD94YfV49cIyY8Ru2rrnlcyVMkNzdO3uTk5lO45nWlYJnkyalIk2rE6/XkHORpPX74mIfXrI+azgkvq+8WrVAEnU7HrVO3032GmJPMLPBeyNgmX/DodgCeXp6Y4kxcP3bdol1q27/2J18AKq6ZjO5FSYhN4I6VPhd4L4ggnyCKVihiNW0DQJ4iuXEv4M79y/eJfPxiR+i+rHOVKjXFTe5CuV9aihthnQR3hRBCCJEtB+4H8caKA7yx4gAH7r/YEVtCiH+vyJAods3fbZGn9drR6xxeeZi8xfJSsUnyxEJPYy2AKZKDB4qicGb7WTVIoigKZ3ee4/wfF7TNs02n09GiXzPyFsvL8Q0nuHzgsrZJliUlJrH9p50E3guk0f8aUur1UmoQOfBeoEUQuVSNUhSvXIwDyw9yZvtZiwCvKc7Eb6N/54MyH7F1znYAXHI7U71VNat9MDQgjD9+2U1SoplS1Uti62CLW15XAu4EcPPELYvt7vhpFxFBEWqZVliaIGS+4vmo0aY6QT5BbJ2z3eI5I4IiWD9jI4mmROp1qYvzPyzNkc9fPuyctwvnXM60HdoGOwdb2g5tg3MuZ3bO24XPX8kXDOwc7aj/bj1iImJZNXkN0WGWk0HdPnuHT6qPYGLLyQT5BGG0N1K7Uy3MSWa2/7DDon18bAK75v9BTHgMpWt6obfRk7tQbgDO7DhrMWL49tk7HPj9oPpvrfiYeIv21VtVwy2fG5u/3cLts5aTpkWHRfNtj+8ZUm4op7actqjTcnR1pIBXfsIehRF0P1hbjU6no3y9ckSHRbPz/3bhksdFDeA6ujlSslpJfK/5cuv07XTBwNTR6Od2nbcIdEaHRbNq8hpiUiZiy6y0fTGrXlbfLV65GMUqFeXC3r84t+u8xfv48sErXNjzF665XchbLC/13qmDo6sDu+b/YXEOFUXhxMaTnNl+lgKl8lOxac7NBbHn131qmhJSzs2m2ZtJSkyifpd6GY6ydsvnRp1OtQi4/YhNszdbHE9FUTi97QwflvXmm/99l+49FBWa/iJTWi/rXKXyve4HKRMAir+XBHeFEEIIkS0Dd51h//1A9t8PZOCuM9pqIYSwyqOQB+d2nWNMo3EsHrmUb977ju97zyU+JoG3P+tELs/k20sz4ulVAEVRWDlpNasmr1GDK7vm/8FgryEsHrnUov39S/cZWmk4YxqNtxgxnFH5q6x66+o453Lm8KojTHxzMktHL2Nah69Z+PEiKr9RCaO99aBDduTyzEW7oW1QzArrZ2xURx6munn6FhNaTGZ0w3FPXQ4uPwRp0jEULO1J6w9aotPpkoPI/ZtToFR+i/QMTu5OdJ/UDSc3J34Z9isTW37J0tHLWDB0IaPqjebImqN4lipAvc51IDUY3b855eqV5ciao2ofXDB0IV80n8jDq750GN6OElWL4+DiQN3OdVDMCvM/WsA3733H4pFLGdNoHOd2naOSleBR6gzyR9YcZdGIJfy19y/0NnraebfBq6ZXuucc2/QLbp26Re2OtaxOKpeRxSOXMthrCLvm/5Gp8lSLRy5Nd9y1y6xu3xD5OFJNxxATEUuH4e0oWNoTgIKlPWk5+E1iIlLTNSQHG19vW4NG7zXk+vEbfFZnNP83eD7Lxi7nq7emM7PrbGIiYqnftT55iuZR2zd4tz7Xj9/g8wZjWTB0IYtHLmVc4/Fc3H+JBu/Wp1qrqgCUrVuGgqU9uXb0Ol80m6h+bszsOpviVYrj4mEZrLJ1sCV3kdxEh0WzZNQy1s/YSHRYNHmL5aXbhK7ExyQws+tsvnprOsvGLuf/Bs/nszqjuX78BuXqlX3mBSa9jZ7X6pfHnGTOMP9wyeolcM3rSnxsAgVK5LeYqK5EleRRvIqi
ULFxBYtgYPXW1XB0deDYuuPqe/inQfMZUWsU4UHhVG5WmYTYBMKe8ZllrS9mVU70XWuc3J3oNqErDs72/DLsV6Z1+JplY5fzzXvf8eOA/8PWwci7X3TBwcWBwuUL02XsO8RGxTGz62y+ee87lo1dzrQOX7Nk1G84ONvz7hdd1SD5i6bT6YgMjmBS6yn8NGg+C4Yu5PMGY7l+/AYNuzegcrNK2lVUqZ8/FRq9xpE1RxlVbzQLhi5k6ehlTGz5Jb8M+xWdTkfTXk3UHMyueV1x9nDm8sEr/PrpYg6tfDJiPq2Xda5ISUFx+8xtnNydKFqxiLZavGQS3BVCCCFEtjxIM0N72sdCCPE0ZWqVZtS6z8hfqgDHN5zg5ulbVGxSgXFbR1OtZXIw52nqvl2HGq2r8+jOI/Yv/TPD4Mp/UcHSnoxc/SkVm1Qg+OFjjq45RkK8iYFz+9NmSOt0t3m/KNVbVqPSGxUJuP2I3b/ssQg+KGaFsEdhhPqHPnWJjYxV0zEAdBrRwSIglquAO+2GtgUFi/QMxSoVZeyW0TToWo8w/1COrjnGme1nsXdx4O3POvHZmhF4FPJQt+Pg4sCQnz/g3S+6oDfYcHzDieTRfiXyM2TBB7QY0Fw9Tg27N2DQDwPIXzI/N07e5MSmkxSrXIxR6z7D6/VS6jZTFSpXiPYft1NHEZ7eljya2DWvK8MWfcS7X3TB6GCrPmfeInkYMKcfvWf0fCm3NseEx6Q77tolxC+UpMQkNR1D+QblqN2plsV2GnSth1dNL4v0DDYGG7pN6MqAOf3IXyIfF/b8xeFVR3h41ZdKTSsyat1Imr//hnpsbQw2vDe5GwPm9MOjoAdntp/l+IYTGB1s6T29J90nvYuNITn9h0tuFz5eOpQGXZNHBx/fcIJH9wLpNqErPaa8h9He8tjpdDpaDW5JiaoluHv+LvuX/qleBKrWsirjto6mUtOKPLzqy+FVR7iw5y/yl8hHn5m9GPzTQBxcHCy2Z035huVxz+/OX/suWoySTOWWz42iFZIDX6VqlMTO0U6tK1imIPZO9tg62FKyWok0a0GR14owcvWneNX0Iuh+MEfXHMPnog9vjejI5+s+4/U21SET6Sgy6otZ9bL6bvEqxRm/fRwNutbD76Y/h1cd4fbZO1RvVY3RGz+neEpAXKfTUbdzHb7YPpZKTSty++wdDq86QuC9QBp0rcf47eMoX7+cdvMvjNHeyIc/D6ZpryZcPnCZM9vP4lHQgwFz+tFtQle1z2bEwcWBwT8N5N0vumDv4sCZ7Wc5uuYYYf6h6n6l/Tvols+Ntz/rhLO7E+d2nufExpPEx6RPBcJLPFePHz7m7vl7lG9QjnzF82mrxUumU57nnS2EEEIIkWLi4StMOpJ8C+6E+hWY2OA1bRMhxCsoxvciccG38Cj5Fnz3pDyhvB/+bkvZWdI1bfP/vIZ5GuNh58Em3w3aKiHEv5SiKGz5bit//LyHT5Z/nC5IK/59Fo9cytmd5xixYniWJzz7N/nztwOs+nINQ3/9iPINMg6k57LNRdci3TlwayGvFXiDvM7/zfdIUFAQ9vb2uLg8uVj5IsnIXSGEEEJky8QGr3FnUBvuDGojgV0hhBBC/GfodDrqvl0Hp1xOnNx06rlGxQrxqol8HMnB5Yeo3rIaZeqU1laLv4EEd4UQQgiRbSXcnSiRkhdMCCGEEOK/Im+xvLT9qDWntp7G/6a/tlqIf51zu84TFhhOqw/efGYKCvFySHBXCCGEEEIIIYQQ4jnVfbsOJauXYNvcHSQlJmmrhfjXePzwMbsX7qX1h60o8ppMpPZPITl3hRBCCCGEEOlIzt2skZy7QgghhCXJuZtMcu4KIYQQ4h9t//0gGi//k8bL/2T//SBttRBCCCGEEEKIHCLBXSGEEEJky8Cdpzn4IIiDD4IYuPO0tloIIYQQQgghRA6R4K4QQgghssU3MtbqYyGEEEIIIYQQOUuCu0IIIYTIls/rlLP6WAghhBBC
CCFEzpLgrhBCCCGy5Yv6r3H/w3bc/7AdX9R/TVsthBBCCCGEECKHSHBXCCGEENlWxMWBIi4O2mIhhBBCCCGEEDlIgrtCCCGEEEIIIYQQQgjxCpLgrhBCCCGEEEIIIYQQQryCJLgrhBBCiGzZc+8R9Zfto/6yfey590hbLYQQQgghhBAih0hwVwghhBDZMnjXWY76Puao72MG7zqrrRZCvMoURVsihBBCCJElOnTaIvECSXBXCCGEENniHx1r9bEQ4l9AfosJIYQQQvyj6RRFLscLIYQQ4vlNOXqF8YcuA/BlwwqMq/eatokQ4hUUff8M8SH38CjVmaRVSdwreBcdOgpG67hdbDfnc9ugM9pj1Bux19trV/9PSVKSKOJQlLz2eTkVcgJbvZ22yQsXHx+P0WhEr5fxOkIIIf6ZTOYE3i7chcO3l1AqT2083cppm/wnBAUFYW9vj4uLi7bqhZDgrhBCCCGyzS8qecRuQWcHbZUQ4lWXCBjgTOhpIhMjaJSrPuHmGFY/WEGt3HXwi/WlrWd77Vr/OSazCQXzSwnsAvj7++Pu7o6Dg3zuCiGE+OczK4nodQZt8X9CTgd35TKvEEIIIbKtoLODBHaF+LdK+R1mlxK01BvsJHeeFUa98aUFdoUQQohXzX81sPsySHBXCCGEEEIIIYQQQgghXkES3BVCCCGEEEIIIYQQQohXkAR3hRBCCJEtf9x7RJ3f9lLnt738ce+RtloIIYQQQgghRA6R4K4QQgghsmXwzjOc8AvhhF8Ig3ee0VYLIYQQQgghhMghEtwVQgghRLYExsRbfSyEEEIIIYQQImdJcFcIIYQQ2TK2Xnmrj4UQQgghhBBC5CwJ7gohhBAiW0bXKUfAR+0J+Kg9o+uU01YLIYQQQgghhMghEtwVQgghRLbld7Inv5O9tlgI8S+jKIq2SAghhBBC/I0kuCuEEEIIIYQQQgghhBCvIAnuCiGEEEIIITJFp9Npi4QQQgghxN9IgrtCCCGEyJYddwKouWQvNZfsZcedAG21EEIIIYQQQogcIsFdIYQQQmTLh3+c5XRACKcDQvjwj7PaaiGEEEIIIYQQOUSCu0IIIYTIluDYeKuPhRBCCCGEEELkLAnuCiGEECJbxtUtb/WxEEIIIYQQQoicJcFdIYQQQmTLqDrlCB7akeChHRlVp5y2WgghhBBCCCFEDpHgrhBCCCGyLbeDLbkdbLXFQgghhBBCCCFykAR3hRBCCCGEEEIIIYQQ4hUkwV0hhBBCCCGEEEIIIYR4BUlwVwghhBDZsvWWH9UW7abaot1sveWnrRZCCCGEEEIIkUMkuCuEEEKIbPlozznOB4ZxPjCMj/ac01YLIf5FFEXRFgkhhBBCiL+RBHeFEEIIkS0hsSarj4UQ/z46nU5bJIQQQggh/kYS3BVCCCFEtoyvV97qYyGEEEIIIYQQOUuCu0IIIYTIlpG1yxI6rBOhwzoxsnZZbbUQQgghhBBCiBwiwV0hhBBCZJu7vRF3e6O2WAghhBBCCCFEDpLgrhBCCCGEEEIIIYQQQryCJLgrhBBCCCGEEEIIIYQQryAJ7gohhBAiWzbf8qPyr39Q+dc/2HzLT1sthBBCCCGEECKHSHBXCCGEENnivfscF4PCuRgUjvfuc9pqIYQQQgghhBA5RIK7QgghhMiW0LgEq4+FEEIIIYQQQuQsCe4KIYQQIlsmNqhg9bEQQgghhBBCiJwlwV0hhBBCZMsnNcsQ+clbRH7yFp/ULKOtFkL8iyiKoi0SQgghhBB/IwnuCiGEECLbnI0GnI0GbbEQ4l9Gp9Npi4QQQgghxN9IgrtCCCGEEEIIIYQQQgjxCpLgrhBCCCGEEEIIIYQQQryCJLgrhBBCiGzZcMOXCgt3UWHhLjbc8NVWCyGEEEIIIYTIIRLcFUIIIUS2DN1zjivBEVwJjmDonnPaaiGEEEIIIYQQOUSCu0IIIYTIloh4k9XHQoh/n4SkBG0RieZEbZEQQgghhHhJJLgr
hBBCiGyZ1LCi1cdCiH+X/HYFeD1XTcyKGR063ivaE3eDO/VzN8CsmLXNhRBCCCHES6BTFEXRFgohhBBCZEVsYhIADgYbbZUQQogc4O/vj7u7Ow4ODtoqIYQQQvyDBAUFYW9vj4uLi7bqhZCRu0IIIYTINgeDjQR2hRBCCCGEEOIlk+CuEEIIIYQQQgghhBBCvIIkuCuEEEIIIYQQQgghhBCvIAnuCiGEECJb1l1/SNmfd1D25x2su/5QWy2EEEIIIYQQIodIcFcIIYQQ2TJs73luhEZxIzSKYXvPa6uFEEIIIYQQQuQQCe4KIYQQIluiEhKtPhZCCCGEEEIIkbN0iqIo2kIhhBBCiMyac/omn+y7AMA3b1Tho+qltE2EEP8hpsQYkpIStMXiBQsKCsLV1RU7OzttlSoxKR5nx/zaYiGsUlCIivLFYHAkOjYAZ8eCap3ZbAKdHh06IqN9AQVnR0+iYwKws3PHbE7A3i43imImPiECezt3FMWMKTEao8GJ+IRwEkyRuDoXRVHMRMX4owMMBidsbV1ISAhHp9NjNDgRHfsIRTGj09ng4lQIRTETHROAk2MBIqMe4OJchMhoX3Q6MNg4YDA4YKO3JSkpnriEcJwdPYmNCwbA3i4XUTEBoJhxcS5CdIwfehs7kpLicbDPjckUTVx8KEaDE4qShJNjAaJjAzEaHDEYHNChJ94UjtlswsmhAPEJYSmfcfHo9bbYGpxINCeo+2s2J2A2J5FkjsfJoQDRsX4oZkCnw9nRk6gYP5wdC2JKjMJocCYm9hGODvmJivHDbDbh6lyM6NgAksM0T0I19na50Ols0KHDlBiLKTE6pX1R4hPCMStJ6HR6EhOj0aHHydGTBFME8QnhGA3O2NgY0ettU15jIkaDAwmJ0ejQYVYSSTTFoNMlrxcV44eiKOjQ4exUMOXfZnQ6PWZzIgoKtgZHdDoDiUlxGI1O6ACjwQmTKZqExGiA5PMQ/5jExFhsbOyws3UjMTEOvd4Gvd6AYk4iLiECB/tcJCUlYDJFgU6Hk0MBQiNu4WCfG3vbXETF+OPkWICY2Efo9Ub0OhsMBkf0eiM6dGrfCI+6h71dLmwNzmq/M5NEXFxI8rlIaRcdG4Ct0RUdoNMbUJTkNgaDY8q+6DAlxmBKjMHJIR9J5gQURSHBFIGTQwESTFHY2NiSmBiLWUnEZIrG1bkoUTH+an+NiQvE0T4f0TH+KIo55d2lw0ZvRKc34GifRz238QnhJJlN2OgN2NjYA6AoSSjmJPQ2tijmRCKiH+Bon4+kpLjkc4gZPTbYGOxISkrAySFfcl9OCMPNpbi67eT3YBQJCVGYlUScnQqlvIcfqP3B2dGT+IQw7GzdMSVGERcfrr5fomL8MBqckvtWUhxJSQk4OxYgKsY/pT8n/9+sJJKQEIHexg73NM+floJCTEwger1RXS86JgBHh7womImJDcbJsYB6Th0d8mI2JxGXEIKjfV6iYwNTXmvy5wtAfHwoRlvX5HViA3BySO4nDvZ5QKcnJjYQR4d86HU2at9Lfe6Y2CCcHAuQYIrAZIrG0SEfkdG+ODl6opgTcXEupN2FLImLDwPA1uiMXm8gODgYe3t7nJ2dLdopKMTHhwMK9na51H8rihkHew+Ltk8jwV0hhBBCZFtCkhkAWxs9oaGh2mohxH+IoiRy5fZP2mLxN/Bwq4xn3ibaYiEyFBJ+Cf+gfTg6eBIT66+W63Q6QJcScChAkjmOhIQwjAYXTImRFtuwM+Yi3pT8XcDW4EZCYjgGGyeMRhdi4wIAcLDLh06nJ8EUQWJSDEaDC4qSSGJSLI72BQAdoBCT0t7RvgAxcQE42hckJs4PR3tPEpOSg2vmpDjMSiIGm+RAcUxsADY2jiQlxQJKyvYgJi4AB7v8JCXFYmPjSGy8P6BDpzeimE3YGl1JMEVga3QjKSmeJHMcOp0Onc6AvW1uYuICsDW4YGNwRIeOxKRYkpLiMJtNKJjR6wzobexTgo8O
xMQGYG+bG73eCDqIiQ3Awb4AsXEB2BrcSUgMw8EuH7HxgdjbJQegYuL8sbPNhaKAwcaWJHMiNja2xMeHkGSOR6fTYWPjhMHGERu9kehYXwwGJ/ToMSXFYmd0Q683EhMXgNHopgZvzUnxmBUTOp0NRhtnEhLDU4J7SdgaXbHR26EAsXHJr1ExJ6LX26Yc6wKpp4PEpDgAEhLCQAeODgVIMEViNidiNsdjMDhia3RFMUNsfAA2ensUkpIvDkBKQFZPkjkeko8+CgpGgws6bDAY7IiJe4SdrQfxCSHJ/ck2N/EJj7GzdU8OLitmksyxan9M7RsO9gVITIzGlBiprqPTGbAzuhGX8DhNu/zEJ4SimE0oKOj1thhsHDGbE0hMigF0GI0uGGwciY0LQK8zgE6PndGd2PhA7IzuKf1bh15nxN4uuW+kbh/AybEw0TEPk8+/jZHExHgMNnbExCX3ubQMBsfk4Dm6lOOrYKO3Q6ezITEpBp3OgMHGAVNipPq+SQ4GJ58fvd4Ws9mEweCMrcGNmDhfy+3bOJNkjsHGxp7ExBgAHB0LEhvjj4N9fmLiAtTjbaN3wM7OHcWsEBsfgIO9J7FxyZ8DNjYO2OjsSEgMw9bWnYSEMOxt8xCXEJzyGhJwdSpFEc+2Fs+fll/gPkIjLuHoUJCYWD/s7fIRF58c8FUUM4qSBGne73q9AYONCwmmUBztPVPeH8nnFkg5F8lB1NRt2tnmIcH0GBSwVc8V6HUGzEqiuk7q9oxGN4w2jmnOjUI+j7rk9aiZ5pVnXXxCGPf81lK2eH8AoqOjMRqN2Nraapty/e4CShTuiq3RLeXfCyle+G3sjLm0TTMkaRmEEEIIkW22NnpsbeRrhRBCCPEqM9g4QEogxFJycBdApzOg1ycHKFJHGqal0z35PmD52HKbyf9OrtfrjGnKbdDrDBbt9TqblP8nl+l1Bmz0tug0IQ1dSrvUYHTyY4Na/qTdk22nvkablH23sXFIWR9Sh8JZW99Gb4tOp1fbkvp/zfFLfv7kfz/5f/Jzpt2u9vikSl4/dT+tPwfqPltux2CTPPpYbYPuyXOn/Je8jk2650/dfmqdTmeDjd5OPfdP2qQ/B0/2S2e5j+iTg90Z0B5nUkZ3A+j1tuiwAc2xSHtsn5z/J/0gfVnabST/22DjoJY9OY4pr0Wn02zHct0n5Wn7a/pzqT2+qay11evT9Jl0xzd5P7V9iQy2lVyffJEi1ZNz++Qcp62ztk+keY8YbBxT6pPbWXtea2xS+o72edMeR8tyvdrftOuQ0ifUxxb1yecwbX3qcXry2p+8Zu1+vjAvcSitZS8RQgghhBBCCCGEEEII8UqQ4K4QQgghhBBCCCGEEEK8giS4K4QQQohsWX3tAV7zd+A1fwerrz3QVgshhBBCCCGEyCES3BVCCCFEtgzfe4HbYVHcDoti+N4L2mohhBBCCCGEEDlEgrtCCCGEyJYYU6LVx0IIkV1TJ26gYa2JmV68By8mNiaB35cepmGtiRw5dF27SSGEEEKIfxUJ7gohhBAiW6Y0qoitjR5bGz1TGlXUVgshxHNzdXckX343i8XeIXn2a4PBJl1d7tzOaSesF0IIIYT413ulgruxsbHMmDGDPXv2aKtUN27coF27dtja2qLT6dDpdHh5eXHlyhVt0yxRFIUZM2ao2yxTpgw+Pj7aZlaZTCYGDhyorjtjxgwURQHgzJkzODk5odPp6N27t3bVbDl27Bh2dnbq865cuVLbJEPTp09X13ueZfr06Rbb6927NzqdDicnJ86cOWNR92/2+++/o9PpaNu2LXFxcRZ1ERERzJgxAy8vL/W42dra0rBhQ7Zt20ZSUpJF+8zy8fGhTJky6LLRp7K6jdjYWJYuXUrFihUt9qVdu3acOHFC7e8ZWb9+fbo+lNHSpEkToqKitJsAwN/fn08//ZRChQqp7T08POjbty/Xrz975M6pU6dwdHRM95zWlmLFiuHv76/dRIbCw8Np3rz5
M/chrQcPHljdnwEDBnD37l1tcwsval+io6NZvHgxlStXxsbGRl2nePHijBs3LsP10vq3bCMyMpKGDRumO34ZLVu2bNFuAlL+npw4cYJ27drh7Oysti9UqBCffvrpM18HQFJSEkePHk23DWdnZ7p165ap9501J06c0BZlypDqXsR++jaxn77NkOpe2mohhHhu3h+3ZN2W4RZLn36NAGj+ZsV0dROnvqMGf4UQQggh/gtemeDu1atXqV69OqNGjSI2NlZbDcC5c+eoXbs227Ztw2QyqeV2dnbky5fPom1W6XQ6vL29adeuHQA3b95k4sSJFs9jjaIofPvttyxYsACALl26MHz4cHQ5PKRAURR+//13EhIS1LLFixdneOzEi6coinoholGjRtjb26t158+fp2LFiowaNYrbt2+r5SaTicOHD9OuXTtatmxJQECAWpcZJpOJkSNHcvPmTW1VpmV1Gw8fPqRRo0b07t2by5cvq+Umk4lt27ZRp04dPvzww6e+V7Ib8FcUhaVLl1KsWDG++eYb/Pz81LrQ0FAWL15MuXLlGDVq1FNfx40bN3LkPaIoCrNmzWLv3r3aKquSkpL4/vvvKV68uNX9+eWXXyhbtixLly7NMID3Ivbl/PnzlC9fnr59+3Lx4kXMZrNa5+Pjw9SpUylWrBg//fRThq/j37SNoKAgbt26pS3OktT3V506ddi2bRvR0dFqnZ+fH9988w3FihV76rkNDw+nW7du1K9fP902oqOjWbVqFXXq1KF79+6Eh4dbrPs0Pj4+9OzZU1ucaXqdDn0O/20TQgghhBBCCGHplQnubt68mWvXrmmLLSxatIiwsDAAunbtyo0bN/Dz82Pz5s24u7trm2eZg4MDc+fOpWTJkpASLF21apW2mYVTp07x1VdfAVC6dGlmzpyJ0WjUNnvhAgMD1UBS8eLFAdi/fz/nz5/XtHy2Fi1aMGDAgCwtlStX1m7mP+fx48ecPHkSg8FA/fr11XIfHx+6du3KgwfJM8p36NCBI0eO4Ofnx8aNG6latSoAe/fupVevXpka4Zlqzpw5rFmzRlucJVnZRnh4OH369OH06dOQsi/nz5/nwYMHrFixglKlSgEwb948Jk+ebDVYlZiYqI6qzZMnD3369EnXn9Iubdu2xWAwWGxj/fr19O3bF5PJhNFo5IsvvuDGjRvpXseMGTP49ttvrb4OQB3hb2try3vvvZfuudMu3bt3x8HBQbsJq7Zs2cK0adO0xVYpisLs2bP5+OOPMZvNuLq6Mn36dO7du8eNGzcYPnw4RqMRk8lE3759M7yTIbv7cvnyZVq0aGG1n168eBFvb2/1dXh7e7N+/XqL9f9t2yAl+BocHAyZ/FwsUqSIxfqpF/tmz54NgKurKz/88EOG58EasHkAAP/0SURBVNbayN/Y2Fh69OjB2rVrAShVqhQrVqzgwYMH3Lt3j3nz5lG0aFEAVq1axciRI596QSPVqVOnaNiwYaYv6gghxKvCbFbYt+cy/+vyAw1rTaRxnUl8Nvx3Htx/bNHu+jU/WjSaypJfD7LityM0rfcljetMYtrkjcTFJg+WeBwcyQ/f7aJ1s69pWGsiTet9ycRxawkKjLDYlhBCCCHEy6RTMopy/MNMnz6dzz//HFICve3bt7eoj46OpmPHjuzdu5dChQpx/PhxChcubNHmRVm3bh1du3bFbDbj7u7Ovn37qFatmrYZPj4+tGjRgps3b2Jra8uePXto2LChRZszZ87QqFEjYmJi6NWrF0uWLLGof14rV66ke/fuGAwGpk+fzvjx44mJicHb25s5c+Zom6fzrOMtnm3Pnj28+eablC9fngMHDpAnTx4URcHb25sff/wRUo7zyJEjLUZyh4eH07lzZzU4v2LFCrp166bWZ+TkyZO0bNlSvcABZLlPZXUbafuJtX0JCQnhnXfeYf/+/Tg6OnLw4EFq1KiRZgsQHBxM48aNuXLlCp07d2blypXpgrdPExgYyBtvvMHl
y5dxd3dn165d1KpVy6JNSEgIHTt25PDhw7i7u3Po0CEqVrTMCxoXF0fnzp3Zvn07r7/+On/88Qe5cuWyaPM80n4OpGrcuDFbt27F2dnZoi2az4TSpUuzfft2vLwsb3NP+xnUsmVLNmzYYBGcze6+JCYm0rNnTzWVi7VzC3D06FHatm1LWFgYhQoV4tChQ5QoUeJft41U8+fPZ/DgwRgMBvbv30+DBg0s6p/l0qVLNGzYkLCwMEqXLs3u3bspVqyYRZu057Z27drs2rULNzc3tT71sx2gWbNmrFu3zqIezfuOZ3yGxMbGMnfuXMaNG6cGgV/E14LQ0FBtkRDiP0RRErly+ydt8Qvz+9LDzPthD63aVGHsxLe01ZCmjUduZ0JDoilbzpNSpfNz/qwPvg9DcHN3ZO68PpQomXx33/Vrfnw0cBFGo4Ho6HjqNSgNwGv/z959hzlRbg8c/86kJ9sLu0tbegcVBREFUURAUdQLilcRFUWuCurFCvoTFLEgooIF1CvYGwoIItdOk6IUqdKXskvfmk12U+b3R5IhyWaXsotX5XyeJw/LzOTNzDuTSXLmzHlb12XgLV3YsnkfD4/4kAP7C6lTN4Uz22ezd08+q1fuJC7eygsvD6Rl6zpRa/G/l5LYjqz0btGThahUUck2du+bS5y9HiWlgQvjAIqiAgqa5sNuq4tfK8ftPoDVko677GBEG1ZzKu7ywAUUiymZMk8+RoMDszmZUtceAGyWWqiqmbLyAry+EiymFHx+F16fC4etDgoqGhrO4PJxtjqUuPYSZ6tPiWsXcbZ6+PxleH1ufL5S/JoXo8GBxZKEs3QvRqMDr9cF+HHY6gIaTtdebJYMfD4XRmMCpe7dgIJqsOD3lWGzZuBy78dmzcTjKcTrcwXmq0Zsllo4XXsxG+MxGO2oigm/vwyPtwS/vxy/5kNVTaiqBVUxYDbFU1K6B6s5FYPBBgo4S/dgt9Wl1LUHiymFMs8R7NYsSt15WC21MKgWnK7dWMzJaBoYDWZ8fi9Go52yskN4fS4URcVgsGM02DEZbBSX5mA0OlBR8frdmI0JGAw2nK49mEyJmAx2PF4nPp8Lv+ZBVYyYjPGUefIxqBY0zYfJFI/RYEdDodS1B5s1E83vxWiwU+LahcNWBxQFNA2/34uGhtt9ABSIs9fFXV4Y6AN/GUajPbD+fih178Wg2kDR8PkCJQoNqgVVNeHxBhKYFBQ0NEzGeBQMmEx2nK5cLOYUysqPAOCw18dZugubtRY+bxl+fPh8TkBF03w4bHVxugJ96/WWUO4pwGqphbvsAKpqwmJKxlV2AIetDk7XXuy22rjLj6D5ytDQMBhsWExJlHuL8XpLUBQVozEOkzGeUtdeVNWEoqhYTCmBfWVOCR7fCgbVitWSitO1V18PgHhHA4qdO7GaU1ENJrzeMkxGB05X4JgLZzYl4POXo6Dg9bkBDZPRARjweItQFTMGgwWPt1h/3/j85cHjZRcGgxWfrwyjMQ6rKZUSV2T5UJMxAa/PidHowOMJXIyMc9TH6dyN3VYHp2uP/j42qDZs1jT8fj+l7r3YbfUodQXOAwaDDbMxCVdZHnZbbUpdudgsGbjK9mM02PH6SklwNKZe1uURrx9u/6FFHCpYGbZPs3C58zAYrGiaH78/cDE1tK8MBjMmYzLusv04bPVwunbr+5bgecQV/DvOXp+S0l1YLRmUlR8ETcNqqYWrbD8AhmC/Oay1cbpz9fZMpkTMxoSwfaNRK+U80lM6hK35iSsrL2Dn3s9o3vA2CMYsTSYTZnPF0lG/73iDhnWvxWwK/K77fcdbNKh7DRbT8f+G/8tk7h6Lpml4vYERuhs2bFjhx25N6tu3L0OGDAGgoKCABx54oMKtrx6Ph9GjR+sBnSeffPKEAwEny+128+677wLQrFkz+vXrpweVv/nmG/bvDxzc4tT64Ycf0DSN7t27k5aWBsHjJVTTskWLFnot4nCJiYk8+eST+pv+yy+/PGaw
5fDhw9x9990UFBTQsmXLEwqOhpxoG/v372fatGkAdOjQgcGDB1fYlpSUFMaPH4/dbqe0tJS33367wrbs3btXz6js0KHDMV832tKlS/VyEPfcc0+FwC7B9Rg6dCgE98F///vf6EUoLCzU7w5o06ZNjWT7h5e4aNy4MXa7PXqRCF6vl+eff57S0lLMZjNvvfVWhcAuwXPQtddeC8DChQsr1BSv7rZs2bJF76OePXsybNiwCvsWoHPnztxzzz0Q3I9LlizR5/2d2ghZsWIFAHXq1KkQ+D0ea9as0S+cDB8+vEJgl6h9+9tvv7F582Z9Xvi5PSkpiRdffDHmZ11KSgpPPfWUfg757LPP9M/HEJ/Px6xZs2jbtu0xy5Ucj4827qbh61/R8PWv+Gjj0R+CQgjxv1bqLOP5l2/kjelDePjRvrz3yV307nMmhQWl/PBdxTE5Skrc/N+T1/D089fz9PPXM/CWLrhKy3n5ha85sL+QW27vxnuf3MXDj/Zl0us3M/aZayl1lvH0k7PIzz9aJkcIIYQQ4o/ytwnuhgsNlHOqGI1Gxo4dS4cOgUj+d999F1GnMXTrbSjw9UfV2Q3ZunWrHpjo2LEj9erV02sFb9q0Sc/mEqdOSUkJixcvBuCSSy7Rp7vdbg4cCFxZysjIwOFw6PPCNWjQgMzMTAgGmsJrakbzeDw88sgjrFixgqZNm/LEE0/EvBpUlZNpY+3atXo5hR49epCamhq9CACtW7emW7dAxsh3332nb3/Itm3bKC4uRlGUClm9x6OsrIzu3buTmZkZ0dfRWrRooQdXYwWydu3apQ9k1alTp2q/X0PngU8//ZSkpCSef/55Pchfmd27d7Nw4UIArrnmGs4777zoRSB4Dgq9pxMTEytkSVZ3WzZu3MiRI4Er5QMHDqxQsiHcJZdcogfk9+wJXKnmb9YGwcHUQsd769atKz3eqxLeZqzALsF9GyrN4nK5IupuHzp0iDVr1kAwa7dFixb6vGgtWrTQy+McOnSowoCOX331FVdddZVe89tkMvHyyy9XuLvkeP37+zXsLHSys9DJv78PrKMQQvwZXHbFWXToGCipBmA0Gri0VzsURSFvb8W7DLIbpHFOh6PLA/y2ZhdrVuXQqk1drru+E0ajQZ/X9aKWXPWPDuzYfoCVK6oe7FQIIYQQ4lT40wd3n332WRRF0W/9JlgvUQmO8D5//nwcDgfx8fH89NNPAPz000/Ex8fryxzPyOMnKjU1lcmTJ+sZcaNHj2bRokUALFq0iMceewz+4Dq7IZ999hkFBQUoisJ1112Hoih0796dlJQU+B8NrBbKUHU4HJUOnuVyuXjnnXdo06YNSnDk9+TkZO666y727NlDXl4e2dnZKIrCs88+G/10CBuFfsCAAaSkpOjtZGRkMHToUD04E0toHbt160ZJSQl5eXmMGDGCjIwMvZ0GDRrw+OOPU1RUdW213bt3s3btWjIzMyNu/zeZTPqt+E6nM2aQkWB2aajWblpaWsRgbNFmz57NW2+9haqqvPLKKyeVUXgybfz6669omoaiKFx00UXRs3VWq1WvObx169YKNT1Dx0NGRoZez/pE9O/fn2+//Za8vLwqs+M3bdpEaWkpBPdDtNAAZEajkdatW0fPPmHh9baffvppzj333OhFKti0aRN79+4FoE+fPlVmMd9www1omkZubm6FoHZ1t+XAgQPUq1cPk8lEQkJC9OwIbrcbn88XPflv1QZRg6mdeeaZVb4nKxOeQV3VOSR0DBiNxohyGvn5+aSkpJCcnIzD4ajy+PD5fPrxfixXX30127dv55ZbbkFVT+5rgdt7tN/C/xZCiP+12nWSUaIucsYnWLFYjLjcHny+owNsAmTVTsJmj7zAvWZ1DpqmcUGX5jjiIs//iqLQ/pzA96Zflm+PmCeEEEII8UfQf8WtX7+e9PR0PYhV2aNp06b6LdSnu44dO/Loo48CUF5ezogRI1izZg1DhgyhvLxcv626sgytU+HQ
oUP6YFgtW7bknHPOgWCQ+dJLL4VqDKx2Km3cuJH27dszaNAg/RZ7gkHOV199lebNm8ccXCic0+nk1ltvpVOnTnz88ccR2YwHDhxgypQptGjRgrFjx1YaVA2ZM2cOTZo04YUXXojINM3JyeGJJ56gcePG+iBisSxYsIAjR47Qvn17ateurU9PTU3VA6G//vprzPIAmqbxwQcf6NmG4dmE0bZu3cqwYcPw+/2MHTu2QpDveJxsG6FSAHFxcaSnp0fPjhDKMPR6vRH96Xa79WMxMzOT2bNn06lTJ+Li4ioE1A8fjhz45EQcOXKE119/HYIBttB7IVyoXEZycjIrV66kV69eFS4Q3HXXXezYceysnPASFzfffDODBw+OXiSmVatWARAfH1+hJvCJqO62DB06lF27dlFeXn7MetuLFi3S71oID17+ndoA2L59u17Sxmg0MnToUP1ik6IomM1mevXqxffff19pgPj888/X250yZUqFcj4E34+hwdKaN29Os2bN9Hlt27blt99+48iRI5XWwQ7ZsmWLHoyOFQg2GAz06tWL9evX8/nnn1e7Rv1TXdtgNxmxm4w81fXkj10hhKhpdesFkhtiKSwopbwssmxNYqIdsznynHlwf+CC3PJl23hu3JcVHvO/WoOiKOzZcwRXaaBeoBBCCCHEH0UP7rZq1UqvSVmVESNGVBgB/FS66667yM3N1YOoANOnTyc3N5cVK1Zw4YUXsn37drZs2UKnTp0geBvyli1b9GVq1QoMlHAqDB8+nP79+0MwU+/iiy/Wa13+kXV2Q3755Rc2btwIwYzG0G3gRqNRH4SnvLyc999/v0Lt0/+VnJwc+vbtq/dbjx49WLBgAbm5uSxYsIAePXpQWlqqB2liKSkpoX///nopjPr16/P666+zc+fOCiPIP/bYY0ycOLHS7V+6dCkDBw7E4/EwbNgw1q5dS25uLjNnztRvlz506BD333+/nl0bzuv18s0330Awyzw8w09RFB544AGaNm2KpmkMHDiQyZMn60Ge/fv3M3z4cMaOHQvBW69D+y1aYWEhQ4cOJS8vj+7du3PnnXeinODt9yfbhtPp1LMLk5OTj/keC8+UDc/cDa8Nu3r1au677z6WLVsWUYYiFFBv1KgR8+bN06cfj8OHDzNt2jTOOussPbP+kUceqZDN6nQ69WD1wYMHueeee5g/f36FCwShCw3hZViiRZe4GD16dMxM4VhCmeWhPg1ls7dr104vNxMXF8ctt9xSaRZ6TW7LsezcuVOvAZuSkkLXrl2jFzmmv0obv/32m95PTzzxBFOmTIk4H3k8HubPn0/37t254YYbYgZuW7duzSOPPALBWslXXnkla9euxefz6c/v1asXeXl5qKrKmDFjjvneisXj8TB16lTKywMBhujzEMBll13GvHnzaNWqVcT0k/Wvsxrj/PfVOP99Nf86q3H0bCGE+FtYvXInX878tcJj4U+bTvqzVAghhBCiuvTgrqIoDBkyJObgPSEdOnTQA5l/lLi4OLKysiJGlU9OTiYrK4tatWphtVrJyMggMzMTi8UCgMViITMzU1/GYDhaF6ummUwmxo8fT9OmgRF1Q9mWf3SdXYJBxenTp6NpGmazmZ49e0bMP/fcc/UMypkzZ7Jz586I+ZUJlcE43segQYOim6iUFqxLGgr4jRs3jnnz5tGlSxeysrLo0qUL8+bNY9y4cVV+aX7vvff0wN9FF13EqlWruOOOO8jOziY7O5s77riDVatW6Vmzjz32mB7si1ZWVkZcXBw//vgjL7/8Mm3atCErK4u+ffvy448/6jUpFy9eHDMDev/+/fzyyy/Y7XY9czpcdnY233//Pf3799cDyElJSSiKQmZmJpMnT0ZVVe655x5mzZoVc8AkTdOYMmUK3333HUlJSYwfPz7mclWpThta2ACGxyMzMzNmrdM9e/ZE1BStX78+r7zyCps3byY3N5d58+bRo0cPCN7G3qdPH2bMmBHWQmyLFi3C
ZDKRlpbGLbfcwq5du0hPT+ejjz7igQceqPC+LCoqigg6JyQk8OSTT+qB/cWLFzNw4EAIBs7uuusuxo8fH/OYDC9x8fzzzx935r7b7ebQoUMA1KpVi3379tG1a1cGDRrE2rVr8fsDt406nU6mTZtGq1atmDx5coV1qMltqYrH42HcuHFs3x64BXXgwIFV1oCN5a/URqjWbcjAgQP55ptvyM3NZfPmzTz11FN62YePP/6Yf/zjHxUCvIqiMGLECD744APS09NZsGAB7dq1w2g06pm/27Zto0WLFnz//ff84x//iHj+8Zo9ezbvvPMOBAPKV18deyR5IYQQxyc9I3B+H/vsdSxcPrrSx6TXb65Q0kEIIYQQ4lSLKK5Xr149RowYET5JpwTr3p7MIDJ/d9nZ2TzwwAP6/w0GA3fcccdxZ+vVlPBR4S+66CI9yzQkIyODa665BoI1HefOnRsx/39h586d+i3IF154IcOGDasQjDcYDNx77736AFLRDh06xKRJkwDIyspi6tSpen3hcCkpKYwfPx673U55eTn/+c9/Kg1o3XPPPXTu3Dl6MomJidxwww0QDKZHD2RFcKCxXbt20apVq0pryBYVFVFcXBw9WRfqg8pu7w6v6zxp0iTOOuus6EWOqSbaOF6qqlYIqAKUlpaSmZmJ2WzmvvvuY9OmTdx55500bdqUrKwsevXqxfz585k0aRKqquL3+3n88ccrDMoW7cCBAxWCzwcPHmTChAn6YIPhiouLSU1NxeFwcM0117Bjxw4effRRPbDfuXNn3nnnHebMmaMPyjZ+/Hg9QzYkvMTFyJEjj1kGIJzX69Uzlvfu3UuvXr345Zdf6NChA/PmzSM3N5e1a9cybNgwTCYTfr+fYcOG8cILL0S0U1PbUhWPx8NTTz3FG2+8AcEA4siRI2Pu48r8ldpwu90oikJ6ejrZ2dksX76cd955h0suuYSsrCyaNm3KyJEjWb9+vX5B57vvvuPDDz/U2wgpLy/Xg/iVMRgM+Hy+Ss9PVfn222+57bbb8Pv9qKrKxIkTTyr7VwghxFGtWgdK16z8ZcdJnZuFEEIIIU6lCiOn9O/fnw4dOkRP5vLLL6d3797Rk0XwtvHx48fr//f5fIwZM6ZC1tap9t133+mZw1dffXXMTMk+ffpgNgcyCt57773jWscePXpw++23H/cj+nbmqqxatUq/vf/666+PyNAOZ7PZGDBgQPRkCA5CtXnzZgB69uxJ48aV3xLcvHlz/fhesGBBzCChoihVltMIr6EbKyD21VdfoWka5513XsRgSAQzXqdNm0bbtm35+uuvMZlMEaUfvvnmG66++mo8Hg8vvfQSbdq0qZAdnJOTw+DBgykvL+f222/nuuuui5h/PGqijZrQpUsXtm3bRllZGS+88ELMY1ZRFO644w49i3H9+vV62YvKtGzZkmXLllXIVF2xYgVdu3atkP3brFkzVq5cSUlJCTNmzIh5cYDgrez//ve/IXhR4YMPPtDnRZe4uP/++08oyBguLy+PAwcO8Pjjj7N48WJ69epFVlYWbdq04eWXX+bHH3/Ua7dOmDBBr61KDW1LVUJlJ8aMGQNAo0aN+PTTT08ogPhXa8NqtTJt2jQOHDjAzp07Y35GAtStW5cXXnhBP8e+9dZbEReA9u3bR8+ePRk+fDgHDx7kzDPP5JNPPmH37t1s3ryZl156ifT0dNavX0/37t154IEH8ByjPni4uXPn0rdvXwoKClBVlbfeeuu4a2gLIYSoXNsz6tGyVR1mzljB99+ujwjwlpV5eHbsbC7sNIa33/gx4nlCCCH+OCf72+t0pCB99XdTIbibmprKww8/HPHGsNvtPProozEDL6c7l8vF3XffzZYtW1AUBYfDAcF6ig8//HCF7MFTpbCwkPfeew+AOnXqxBwwiuAo76HSBMuXL+fnn3+OXqSCYcOGMXXq1ON+HO/gUQQDswSDt+3bt4+eHaFFixZ6pmG49evX6/1ct25d9u3b
R15eXsxHcXGxPvjXvn372LNnT1RrgXWJDsoer/z8fH7++WcUReHKK6+Mns3KlSu566678Pv9NGrUiDVr1kSUfrjkkkuYMWMGH330Eaqqsnv3boYMGaIPJhbKMtyyZQtNmzZl1KhRJ5whXhNtnCiXy6XX/zwZJpOJIUOG6OelWAPRhWvZsiUdO3bUM1WnT5+u96nf7+eee+6pdDCxqiiKwj//+U89YLp48WJKSkoqlLgYN27ccZe4qMyFF17I/fffH3PfdO7cmSeffBKCgeDoYPXxqGxbqhIatHDChAkQDIbOmTOHli1bRi9aqb9TG7F06NBBD6hu2LBBL/Xg9Xp59NFHWbhwIQCPP/44y5cvp3///tStW5emTZsyfPhwNm3apJ+jJ0yYwMcffxzWemyhi0ZXXnklpaWlemB30KBBf9iX3PfX51D/tbnUf20u76/PiZ4thBB/aYmJdv790OXEJ9gYPeozbrz2FZ4ZO4vHR37KVZdNYM7slWQ3SKd3n6N3zf2+KZceXZ+i35UTOXzo6N1aJzpdCCGEEOJYKgR3AXr37s3ll1+u///222+nY8eOEcuIwA/qSZMmMWfOHAD69evHt99+q2fUTZ06lVmzZkU969RYsWIFy5cvh+At3Y0aNUKJUQ/Xbrczf/58CK7/9OnT/7AAdCyhQaEURUFVYx6Outq1a+sDxIUrKCjQ/x47diy1a9eu8hEqA1FaWhozc7c6tm/fzoYNG6hfv37MINH06dMpLS1FURTGjx8fcxlFUbj22msZOXIkBPdtKFP1448/5o033kBVVV555ZXjrucaribaUBQFozFyJOmq5OfnV/s4q1Onjh50z8nJOWYgMlyoT//1r39B8D1yrABxZdLT08nMzARgx44dFBcXR5S4ePrpp0/qfGk0GvWLQxwjk53gYHuhwGz4YF8nIta2VGbfvn307t1bv4jUunVr5s+fH/MYrszfqY3KWK1W/e6B0tJScnNzIVg254svvoBjBO5TUlKYOnUqWVlZABEDLsbi8XgYM2YMt9xyC36/H5PJxEcfffSHBnYBRvywht1FpewuKmXED5H1iYUQ4u+gRcvavP3eUPr0bc+BA0XMnb2K779dT1yclX8N68Frbw0mMyvwG0AIIYQQ4o8UM5pms9kYPXo0drudrKws7r777j/0R+JfRXhAJysri3HjxtGpUyeefvppAPx+Pw8++OBxD1x2sjRNY+bMmScV3Pn666/17NnTUXUDjtF+/vlnSktLOeecc8jIyIiYV1JSwm+//QbBQHVVAUAlmPkbypb/9ttvKSkp4c0334TgsXXppZdWCN4risI555xDaWkpAO+8844+fdCgQTXSBoDD4aBOnToQDNweK0heVFSk/x0afPBEJSUlVRnsPBZFUSIuWq1YsSJi/vGyWq169nfIm2++qWcm/+tf/6rQn4qiULt2bXbt2gXATz/9RHx8PIqi0K1bN0pKSrBarREXL8LLf8QS3h979+7V6/WeiFjbEsuKFSvo2LGjnnV66aWXsmDBgioH4Iz2d2rjWELvjXBbt27Vy+ZceumlVR7LjRs31rN3N2zYUOlnyJEjRxgwYIBeViItLY2ffvqJ/v37o/zBn9ke39HPn/C/hRDiVLjhpgtYuHw0o0ZXPmBkaJnzuzSPnkXzFrX5ZsGoiAHQQtOqajO9VgIPjbqSb34aqQ+i9umse/nnwPNxOAIDO4eE2vts9n2kpsWf9HQhhBBCiGOJGdwFaN++Pbfffjv33HNPjf7o/bsIr1mqqiqTJk3S+2nw4MH0798fgpmcQ4YMOaEswxO1c+dOZs6cCUD9+vV5++23+fTTT6t8hAIHBQUFeibr/0LduoEBKsrLy3G5XNGzI+Tm5sYciCg8+2327NlomnbcjxMZ8OpY3G438+bNg2Cd4qoyWxs2bHjM2/YzMjL0TNUTqbv5RznjjDMgGLQ+ePBg9OwIodrENpuNevXqRcwrKCjgwIEDlQ4eFxK+/9PS0rBa
rdGLHFNaWpoeMI/u05KSEvbv319herTi4mK9pEOtWrVOaj0qE+rTE2U0GiOCeTW5LXPnzqVbt27s3r0bgDvuuIOZM2dWWss3lr9DGx6Ph/379x/XuTx0vBuNxpglXlq0aBE9KYKiKDRr1gyCNdxjXYTas2cPPXv25PPPP4dg9vHPP//MeeedF73oH+Lpbm2JNxuJNxt5ulvb6NlCCCGEEEIIIU6RSoO7iqLw5JNPMnz48OhZpz2Px8MDDzzAli1bABgyZAh9+/bV55tMJsaPH69nKH7zzTe8+uqrJ5VZezyWLFmiD0p2xRVXMGjQIPr161fl48EHH9SDQR9//DH79++PavWP0apVKwhm0K5fvz56doS9e/fq2aTh2rVrp2/L0qVLo2f/YXJzc1m5ciUpKSnHHFRux44dVd5qDbB//359MCaTyYTD4WDGjBnk5uZW+fj666/1AGa/fv306S+99FKNtBES6ndN0/j111/16dFcLpc+Pysri/r160Ow5ukll1xCcnIyzZs3Z+PGjVHPjLRp0yZ9/3fo0AGj0Yjb7eauu+6ibt26tGjRQn8fVGbz5s36RYTmzY9m8tx2223Ex8eTlZXFd999F/aMirZt20Zu8Fb7Nm3akJSUxEsvvVShD6Mfq1ev1rNxO3XqxJYtW8jNzWXGjBl6OYaOHTvqFwWOdSznhgW7GzRooLdR3W0J9/nnn3P11VfrdVxffPFFXnnllROqv/53aGPevHmYzWYyMzN5+OGHqzyX5+fn68HdOnXq0LBhw+hFjnm3hKZp+iCRBoOhwoWinJwcevXqxS+//ALANddcU+PZxydqyBmNKLrvaoruu5ohZzSKni2EEEIIIYQQ4hSpNLgLEB8ff1w/fE8nmqYxceJEPv30UwgGmcaOHVvhx3d2djZPP/20Xkf2scceY9GiRRHL1ASXy8W7774LwSyxAQMGHNftuOedd55eFmDTpk388MMP0Yv8ITp37qzfwvzFF19Umr3r8Xj0+pjR2rZtqwfqPv/88yoD1YcPH6Zjx46kpKTQqVMnPYOvJqxbt459+/bRtm3bCtmpAHFxcbRr1w6Cgeoff6x8RGVN05g9e7beHxdddBGKopCamkpWVlaVj7S0NP0YCJVWycrKIikpqUbaCDnzzDP1GqUzZszQB32LtmHDBv0W+N69e1OrVi0IlnY499xzIZi9O3fu3IjnhSssLGTKlCkAmM1mPXhutVoxGAzs3buX33//vcrj2OVy8dFHH0Hw4lV4WYxQJrumaXz22WcxMyUJHodvvfWWPr9nz54oikJSUlKFPox+1KpVSz9PWCwWMjMzycrKIjU1Ve/rtm3bcvbZZwPwySefVHp8asFSLKFgd2gAL2pgW0KWL1/O4MGD8Xg8qKrKBx98wPDhwzEYDGGtVO3v0kbz5s312sTz58+PORBjyH//+1/9YsZ5552nl2dp0qSJniE8c+bMSt8vBIPuoWO5efPm+gURgu+FwYMH6xfDBg8ezHvvvXdc2cdCCCGEEEIIIf5+qgzu/plEZ5T9r4TX2U1KSmLy5MmkpqZGLwZA3759GTJkCATLDgwZMqTKoMDJWL16tR4EOPvss2nb9vhuh01MTOTqq4/WFJs2bVqlgdVTqUGDBvTr1w+CQZMXX3yxwu35mqbx/vvvM2PGjIjpIRkZGVx33XUQDFTff//9MeuP+nw+pk6dyooVK8jPz6dhw4ZkBQctqi5N0/TBks4///xK62kOGjQIu90OwMiRI1m9enX0IgB89dVXvPDCCxC8hfvCCy+MXuR/Li0tjZtuugmCdUwnTZpUoQzAkSNHeOCBBygtLcVsNnPDDTdEBBCvueYavT+eeeYZlixZEvbsgFCmfChAfOutt9KhQwd9fv/+/TGbA/XynnjiCbZu3arPC/H5fLz44ov64IeXX345559/vj6/a9euetbj22+/zYwZMypkZ4Yu7EybNg2CpTf69OkTsUx1JSYmcvPNN0OwRuvIkSNjHsvf
ffcdkydPhuAFph49eujzamJbDh8+zN13301BQQGqqvLJJ59w3XXXHdeFo5C/Uxv169fX6zVv3bqVMWPGxDxfLl++nKFDh6JpGklJSTz44IN6QL9p06b6OXfFihU8/fTTFd4vBIO3jzzyCHl5eQAMGDBAL+2gaRrPP/+8npF9++2389prr8lFWCGEEEIIIYQ4jf1lgrvhgwu99957bNmy5bjqdNak8Dq7AI8++miVg2IZjUbGjh2rB6I2bdrEE088EfMHPcCCBQsYMmTIcT3eeustAObMmaOvz9VXX33MOq7hLr/8cj1o/sMPP1QaaJw0aVKF1z/W4+GHH6agoCC6qQoUReG+++7TS1iMHDmS3r17s3DhQvLy8liyZAlXXXWVPhJ8Ze677z66d+8OweOjVatWTJkyhZycHPbs2cOXX37Jeeedx8iRIyEYmHz00UcrZFyfrMOHD7N8+XIURdEzJ2Np3749jz/+OAC7d++mY8eODB8+nOXLl+vbe9NNN9GnTx/99vGxY8fqtYn/bIYOHar3+5gxY+jXrx9Llixhz549fPTRR3Ts2FG/+PDkk0/SqVOniOeH90dBQQHdunVj+PDhrFu3jj179vDpp5/SsWNH3njjDQjWFX344Ycj9tsFF1zA3XffDcCWLVs4++yzGTduHFu2bIm575s2bcrkyZMjAmL16tXjmWeeQVVV/H4/AwYMYNCgQSxZsoS8vDy+/vprevbsyUMPPQTB42fcuHGVBvGrY/Dgwdx+++0Q41het24dw4YNo2fPnnqw8qmnnoq4wFQT2/Lhhx/qA84lJCTw5ZdfVniPx3qEzkt/tzaMRiOPPfaYfp5666236Ny5M59++il79uxh3bp1DB8+nAsuuEA/7z366KOceeaZlbYxYcIEOnbsyHvvvUdOTg45OTlMmTKFdu3a6XXQu3fvzm233aa3sX79ej2orygKeXl53HXXXRXWPfpxvOdjIYQQQgghxJ+fwvEnqojTg6JFp3T9Se3YsYMuXbpE1NSMj49n4cKFnHHGGZSUlNCnTx9++uknLrzwQubMmVOjgRePx8MNN9ygl2Po378/77//fsRgXpVZtmwZF198sX4L9bvvvsuNN94IwK+//krXrl1j1pKtyk033cRzzz1Ht27d2LRpE0lJSSxcuJA2bdpEL1opr9fLwIED9VvV77rrLiZNmoSiKDz77LM8/PDD0U85bvXr12fp0qV6ZuygQYN45513sNvtLFiwQL/1PGTjxo1cc801ldaiTEtLY8iQIYwbNw6CWZ6h4FTIkSNHGDp0qL6PKtO4cWM+++yziMALx7GOIV9++SVXXnklhK3HokWLuOiii2jWrBk//fQTaWlp0U/T+Xw+XnnlFe6///5KA/0EA1EfffQRvXv3jp5VpfBj6qabbmL69OnRixzTibSxZ88err76ar3+ZyxDhw7l5Zdfjvl+8Xg8PPvss3pGfGW6d+/Oe++9p98eH+542zjvvPP44IMPaNCgQfQsNE3j3Xff5bbbbqtyv7Ro0YLPP/9cL0lxvPLy8ujUqRO7du065jnK6XRy991365m1sdjtdj755BM9ozRcdbalsLCQnj17smzZsohlj0foWPk7tRHuWOcpgvWx33zzTQYOHBgzO3jnzp3885//5Oeff46eFeG6665jypQpERfsRo8ezZgxYyKWOx7R5+PKhD5HqyoZU5l31+fw0I+/AfBst3b0qZ0QvYgQ4jSiaV42bHs1erL4H0hJbEdWerfoyUJUqqhkG7v3zSXOXo+S0qMlwhRFBRQ0zYfdVhe/Vo7bfQCrJR13WeTgylZzKu7yQAkqiymZMk8+RoMDszmZUlfgTlabpRaqaqasvACvrwSLKQWf34XX58Jhq4OCioaGM7h8nK0OJa69xNnqU+LaRZytHj5/GV6fG5+vFL/mxWhwYLEk4Szdi9HowOt1AX4ctrqAhtO1F5slA5/PhdGYQKl7N6CgGiz4fWXYrBm43PuxWTPxeArx
+lyB+aoRm6UWTtdezMZ4DEY7qmLC7y/D4y3B7y/Hr/lQVROqakFVDJhN8ZSU7sFqTsVgsIECztI92G11KXXtwWJKocxzBLs1i1J3HlZLLQyqBadrNxZzMpoGRoMZn9+L0WinrOwQXp8LRVExGOwYDXZMBhvFpTkYjQ5UVLx+N2ZjAgaDDadrDyZTIiaDHY/Xic/nwq95UBUjJmM8ZZ58DKoFTfNhMsVjNNjRUCh17cFmzUTzezEa7JS4duGw1QFFAU3D7/eioeF2HwAF4ux1cZcXBvrAX4bRaA+svx9K3XsxqDZQNHw+NwAG1YKqmvB4A4MUKyhoaJiM8SgYMJnsOF25WMwplJUfAcBhr4+zdBc2ay183jL8+PD5nICKpvlw2OridO3BYa+Hx1NMuacAq6UW7rIDqKoJiykZV9kBHLY6OF17sdtq4y4/guYrQ0PDYLBhMSVR7i3G6y1BUVSMxjhMxnhKXXtRVROKomIxpQT2lTkleHwrGFQrVksqTtdefT0A4h0NKHbuxGpORTWY8HrLMBkdOF2BYy6c2ZSAz1+OgoLX5wY0TEYHYMDjLUJVzBgMFjzeYv194/d7UFUzTtcuDAYrPl8ZRmMcVlMqJa6ciPZNxgS8PidGowOPpwiAeEc2Jc5d2G11cLr26O9jg2rDZk3D7/dT6t6L3VaPUlfgPGAw2DAbk3CV5WG31abUlYvNkoGrbD9Ggx2vr5QER2PqZVX8fRqy/9AiDhWsDNunWbjceRgMVjTNj98fSFwM7SuDwYzJmIy7bD8OWz2crt36viV4HnEF/46z16ekdBdWSwZl5QdB07BaauEqC5TtNKhmfP5yHNbaON25ensmUyJmY0LYvtGolXIe6SlH7xY+GWXlBezc+xnNGwaSdZxOJyaTSb/rONzvO96gYd1rMZsCv/1+3/EWDepeg8VUcXDuyvxlMncbNmzIf//7X7p06aJPKy4u1gcDOpWi6+w2bdqU8ePHxwxUxdKxY0c9OxFg2LBhrFq1KmKZk/HDDz/oQYbOnTuf8GA6RqORQYMG6QGImTNnsnPnzujF/hAtW7bkl19+YdKkSbRu3VqfXqtWLf7973+zcePGiKzPWH2fkpLCxx9/zNKlS7nuuusiRqk3mUyce+65vP/++6xdu7ZCYLe6Fi9ejNfrpWPHjpWW6QgxGAwMHz6cnJwcRo0aRePGjfV5qqrStm1bnn/+eXbv3n3Cgd3/hbp167JgwQKmT58ese9MJhOXX345S5cu5dVXX425zwgu9+ijj7Jr1y7+/e9/R2Tpm0wmevbsyZw5c5g/f37MwC5hbWzatIk77rij0jYWLlwYM7BLMBPypptuIi8vj2effbbCfunYsSPvv/8+K1euPOHA7olyOBz85z//Yc2aNVx33XX6YGkE63n/3//9H7t27YoZ2KWa21JaWqqXBDhZf6c2wrVs2ZLffvuN2bNnc8EFF0Qc09nZ2YwaNYqcnBxuuummmIFdgqVoFi5cyJw5c+jZs2fEvk1OTua6665j6dKlfPjhhxXuxNixY0fE//9M7v9+DXklbvJK3Nz//Zro2UIIIYQQQgghTpG/TOauEOEZs7Nnz+aKK66IXkQIIcT/QNrLszjsClxpT7WZ2TIwMOihEOL0JJm7fx6SuStOlGTuSuauZO5K5q5k7krmrhAnZM2aNdSuXZsWLVpUWU5B0zQWLVoEwXIc4aPHCyGE+N965sJ2JFlMJFlMPHNhu+jZQgghhBBCCCFOEQnuiv+pOnXqkJyczO+//8748eM5fDhwhTfazz//zNSpUyE4CFf4beZCCCH+t247oyH5915F/r1XcdsZDaNnCyGEEEIIIYQ4RSS4K/6nUlNTueaaawBYsWIF7du3Z8qUKeTk5JCXl8fy5csZPnw43bp1o6CgAFVVGTVqVKUDUQkhhBBCCCGEEEIIcbr42wd3n332WRRFqfZj0KBB0U2LGqAoCg8//DA333wzALt27WLo0KE0aNCA2rVrc+65
5zJp0iQ8Hg8JCQnMmTOHHj16RDcjhBBCCCGEEEIIIcRp528f3BV/fg6Hg//85z8sXbqU6667juTko0WjVVWlbdu2PP/88+zevZvevXtHPFcIIYQQQgghhBBCiNPV3z64+9BDD6FpWrUf06dPj25a1CBFUTj33HP56KOPOHLkiN7vPp+P3377jREjRpCQkBD9NCGEEH8C09buJGPSbDImzWba2p3Rs4UQQgghhBBCnCJ/++CuEEIIIU6tB3/8jQOlZRwoLePBH3+Lni2EEEII8ZekKNFThBDiOP2B5w8J7gohhBCiWjRNi/m3EOI0dZLngZN7lqiKJr0qTpTfFz0lSNPfpX6OLqdpPpSow0zT/Po0TfMH4xtaxJvcp/lQNA0tFPzQ/EdnHi+/Lyp2Evl9JLQOStg5SdE0fPj1aYoW3B7Ctt3vi4rKaGhacHvxo/kD647fF9YroUX9oPn086A/7LUiFtP8gecG2/UH+yPwd6ANze8L9m9gW/RWNL8+ndDrB/svYrrm1/soYmtCywYj14H9dXQdfdrR1yW0v0P8PvD78CuK3ncKytH2NQ0t7BgK9LeGPzxK7vdVOOFrBPtN3wdHX9XL0f2iBPdZ2JbrAv1w9LgMvWboGA3fz+HbqwVfN+I4juqT8O0K9V9g+agNiRLYl0f7ONRvoWeF77twmuYL7Gd9QmCZ8OMEfR3ClgnfLv3fwGuH1juW8OM7JPTOPSpw7Af+jDxPhJ5/okL7RtP8EesXPj3Wa4W2209YH8WY7wtvM7RkhfaOthG9xdVSRX+fCoomv8KEEEIIUQ1vr92pZ+w+160dV9VNjF5ECHEa8Zbup3zzfLRmF0fPqlKx38mvno10s5wTPetvq9RbwG7nWlTFED3rmPx+vz74cyx+zceZ+xOxmVMoa9UperYQMZXsWYJavI8ltVTiDUm0MTcBwFtWgH/3ryyul0Dbwx5SXWWoihnN60JTTSjlpfzerBWF/iI67DqCUucM/AYjas5yFFQ0VWFLRioHzT7cWhld9zhRVSOHrCa2ptpIKS6lxeFSUEz4jAqq38fvySZc8WmcZWqOd8cCjHFZbIrzUGDw0eRgESluL2owUFimavxcL4GGR0pp4DajlRWhoaAaLfySZSOr2EOdUvB7S9kdbyYn0UL7fU7sHh/77Qa2JZqpW2aiQbEPv+YFXzCQ6NfQ/OUsa5hKuWrigt2F+DUNBQ0MJgqNGmvSzWh+DxfvKcOvqKAoLKttx2M003lPPopPQzEaUBULPq0cg19lQV0rrYxN2OTZimIwc86eI2xMtVJktdJ5Tz6qLxBg1hSF32vF4VL8tM9zBoJ0BiN+ReO3dBulNjudcg7rQSlNAcUPqsHMgrpW6hS6aFDkRfVrwYAlLGiYhMmn0XFPAapiRFNUNEVja7KNQ3FWzt5zBJPPz+/p8eTbLbTPzcfsVQLhQs2PovnJcxjYmmynVaGfFJcXxVMeWAdVZWkdB7ULSsl2GdA8peQkWNiWbMXo99F1txNUA6BQpvpZXMeOUTXTdVdRoF8VUA1W/B4nqEaUphezxf07B1U3mcUuGh4qQVHAbTSyrE48pZqLZM2K32DirLwizJqK0edjQR072Voy2Tnb8RkNLK8dz5kHnNj8Jpamq5g0A+3ziig3KCypl0D9QjcNi4KBdJ8Xv9GEHw2jpqB5y1AUlSX1E/EZzFywuwCl3IWmGvEbFXYkOTiYEEf2oUJq++PQ6p7Jz65f8KqGwL70a2hoKIrCHruRHYlGWlpbsca7mQbG2pQV7aXVYVewbwP9srReEomucs486Aa/j0X14kA1cW5uMQZf4KKIoqionjL8qooSl47izGdl7ThKTSbiFTt+zY/Tk8/5uW4Unwe/wYAanwWZrfDuXIShvIxNdTNxGVVc/lLMfjhnTxF+kwHFr7HfbmJbnJ8O+0rZmGrlsN1E/SIPzUoM+H3uwH7yulFUI+WKjyUZRhIVA+mGFJqn9Ag/reicm2eByY7PfRiD
JRW/cx/EZ2AoyGNrooFdSQ4Auu4tRbGnYCjYT1lKFpbifLxGMFiS0ZyHUDJbgyONPWU72MZh3Fo53fMtGP0qfnchSlwt1II8cho0ZCcFGFBp5UsmZc82yhLTsLjdeH0uDHGZ/JRQxDn5EO904/U6Uc0OcjIcWE1pNErqHL0Jx8115Hf8e3/FaIrD0uJKnE4nJpMJs9kcsZymeXGt/RRMVuwtrw78f92ngIq97XURy1ZFMneFEEIIUS23tG3AwWFXcnDYldzStkH0bCHEacZoTQWPG4e97gk9kuzZlOOtMP3v/Ii318fpy6fYe+iEH07/EUp8hytM1+f78lGt8WAwRu8iISpltWdhUC1YzMnYzSn6sZqQ1AqL24WbMmzWWpiMCVCaj8EYh+IuRtE0bKZkFNWCQVOxWWsRn9QC1RcIWCmuYhJs9dBUFQ8+TIoNiz2TOEsq+f4iDJYkVMWI6ivDYk7DZE3DakzCoFpx2OtitaRjsWeS6GjIYX8BNmsqZmsaBoMdFQNWt5syvCgmGzgPY1BNqD4vuIuJM6djsaRBaT5GczKm8jKKVA8GxYji8WIt81BgBrM1FUrzMZmTMahWVL8fvG5UwGZKpVhzYtAMWMzxmO0ZWGy1sGHBiQeXUQFvOQa/HwMGbKYkirQSjIoFkyUJszUdnIcD/7oL8SgKNlsGFlMS+f4iTKZ44iy1KNJKMClWjKoNoyUJQ5kLq7UWBWo5it8H3nJUv4ZJtRNnrhXoO03DoFpQy0oxqRaMtjRUnwe/akA1x2EyxAX71gu+cjTVwGHVjcGvoJQ5MWDEbE7CZkkj31+E2ZKKocyFzRpo32RKwmxNxWxNwWSMw2B0YC4ro1gtx27LxGxOQkVB1fyoPh82UzKa0QKuQlRUzH5w4cFrsqBqoPo1lLJSrGXllCtQrLlQfRpG1YzJkhzcT0kYNRWHvS5Jjkbk+4swmpMD+1U1Y/P4KdFK8aPhsATW0+qoi9WWiaIp+FQDyQnNwFeOETNeg4rJGjgGzKZ4Sox+VL8fS5mbMnwYzAmYjHGBILi3HINmxGJNg7ISFJ8Xxe/Fak7GiQsVA4rfh+LzYHC7sVvSOewvwGJJxmCw4bDXxWZJDe5LCyZbOoYyNyZbJqYyN8UGPwm2OpRp5VhNyfjNDlS3Uz/mrG43HvwUqp5AX/l8mE1JFOLCZErA7MjE6C7FbE1D8XowlJdhjauP6vORaMmiRCvFYUohzpxCocGLwa+h+P0YylxY7BmB9XPUQy13keDIxmCwYTDacRoVVJ8Xq6MuZmsaptIiCk0aRk3FrBlw4cVvtAT3T2Lw3/hgP5bhNiiYjYkR2b/RzLYMzNZ0LOZ0LPZaGDFjT2oOHjdmH5RqblTVhMWcjC2hEfjKiY9vgoKKzVEPi60WRgxYLek47HVJiWtOiebCr2g4HA0wmZMwqTZsyc3B4yIlrhlexY9H0bDGNQCPi/j4JhgMNiyWNCz2TDwKWG1Z4C7E4PViUMw4LBkYVUv06p8Qiz0TxVV0zLCrohhRNR8Gg1X/v6IpKB5X9KJVqvpVhBBCCCGEEEIIcfqoxs29FUuBRP//ryOw5uG3uf91t4VYe6KSrP+TpUVW3giI9RqxpokIYYUuTlz0Tjhmf0c/4TQQcbD+PbZfgrtCCCGEEEIIIYQQQgjxFyTBXSGEEEIIIYQQQgghhPgLkuCuEEIIIarlP7/tIO3l2aS9PJv//LYjerYQQgghhBBCiFNEgrtCCCGEqJaHflzLYVcZh11lPPTj2ujZQgghhBBCCCFOEQnuCiGEEKJawsdpOOaYDUIIIao04z8L6NvuUVb8tCl6VkzbNuRybccxvDhqRvQsIYQQQpwGJLgrhBBCiGoZf1E7MhxWMhxWxl/ULnq2EEIIIYQQQohTRIK7QgghhKiWQW0asO/uK9h39xUMatMgerYQQohTqHGr2nyy/HHufeof0bOEEEIIcRqQ4K4QQgghhBBCCCGEEOIPoEVPENUkwV0hhBBC
CCGE+JPxeHzM/XApN134NH3bPco/zx/L9InzKS0pi1hOau4KIYQQpzcJ7gohhBBCCCHEn8w7L/6XN56ZS50GaVx0xVkYjQY+f3sh/zfkbfIPFkcvLsSflIy0KoT4M/t7nKMkuCuEEEKIanljzXaSXpxJ0oszeWPN9ujZQgghTsK+PUe4/9lrGTftNu596h+8+d/76XHN2WxZt4cZ/1mApsltrUIIIYSQ4K4QQgghqumRn9ZSWOahsMzDIz+tjZ4thBDiJPS+tiPn92yDogSyiswWE/1v70ZqrQSW/bCRI5K9K/4S5CKEEOLP7O9xjpLgrhBCCCGqxage/ToR/rcQQoiT165jYz2wG5KcFk+D5pkc3l/E/r35EfOEEEIIcXqSX2BCCCGEqJbx3dpRO85G7Tgb47u1i54thBDiBJktJtIyE6MnY7YYSUhy4PP5cRa5omcLIYQQ4jQkwV0hhBBCVMvANtnsvasPe+/qw8A22dGzhRBCCCGEEEKcIhLcFUIIIYQQQog/kfIyDwWHK9bULS/zUlTgxGwxkZQaHz1bCCGEqIa/R/3Z05EEd4UQQgghhBDiT2bDyhw0LfKH9qH9hWxdt5eMuslk1E2OmCeEEEKI05MEd4UQQgghhBDiT2b+Zyv4fc1u/f9uVzkfv/4DhflOLrriTBKS7BHLCyGEEOL0JMFdIYQQQlTLlNXbiZ/4BfETv2DK6u3Rs4UQQpwgg0ElOT2eh26ayqhb3+LFUTMY0msCP85ZTfvzm9L72o7RTxFCCCHEaUqCu0IIIYSolpE/raWk3EtJuZeRP62Nni2EEOIEGU0Gho25mltG9GLH73n88OUqDEaV2x/uwyMv/hN7nDX6KUIIIYQ4TUlwVwghhBDVYjYc/ToR/rcQQogT949bu/LJ8sdpcUZ9rhp0AR8sfpRZv43l7e8eos8/O2G2mCKWb9yqNp8sf5x7n/pHxHQh/hyU6AlCCPEn8vc4R8kvMCGEEEJUy4SLz6BevI168TYmXHxG9GwhhBBCCCGEEKeIBHeFEEIIUS3/bFWfXXf2Ydedffhnq/rRs4UQQghx2tKiJwghxJ/I3+McJcFdIYQQQgghhBBCCCGE+AuS4K4QQgghhBBCCCGEEEIAivLXqsUrwV0hhBBCCCGEEEIIIYT4C5LgrhBCCCGq5bVV27BP+Bz7hM95bdW26NlCCCGEEEIIIU4RCe4KIYQQolpGLViHy+vD5fUxasG66NlCCCGEEEIIIU4RCe4KIYQQolqsxqNfJ8L/FkIIIYQQQghxaskvMCGEEEJUywsXn0HDJAcNkxy8cPEZ0bOFEEIIIYQQQpwiEtwVQgghRLUMaFmf7XdcxvY7LmNAy/rRs4UQQgghhBBCnCIS3BVCCCGEEEIIIcQpoERPEEKIP5G/xzlKgrtCCCGEEEIIIYQQQgjxFyTBXSGEEEJUm0/T8Gla9GQhhBBCnNbku4EQ4ihF+bNlyv49zlES3BVCCCFEtUz+dSu252dge34Gk3/dGj1bCCGEEEIIIcQpIsFdIYQQQlTLYwvX4fFrePwajy1cFz1bCCGEEH8l1cisU/4m9SvRK3Ee3Z6/+rad6rVXlFP/GuI4nPBOCH/CCT/5JP1Rr1OJv+HBKsFdIYQQQlSL3WSM+bcQ4vSked2ocRnRk4+pXCsn2ZAcPflvrcxfCvZ6mOzZZDpakOloQS1Hc1R7fQz2+vq0WI9Uc2Nq2ZpVmB56pNsbg98Lmj/6ZYWolOZ1oxgtmFUzHs0TMV2NzyJVTcbvKwfAmN4CAFNyAxSjFZ/mw6JaMNgSwRd4rupIRzHHYUhuhMfvJs4QT4oxFb/JjOZ149W8NLQ0wujXUCzxGBLqgRJ4PZ/mxagc/V6hlZfg1TzUM9fH6ysDXzmK0YJiiUeNz6SWMQOTX8GY0hjFYEG1p2JIaYzf70XzlWFIbgSaH19cOtnUwmM0YYjPgOT61C+zoHjLMKW3AL8XxWRDjauFMakeitGKovmpY66LaktE
00DzlQfW32yjjrkuNsWGGp+JYkkMPFfTqGuuh99kB78XzePGkNIYzePGmNKYBDUOr+bFrJqpZ66PHwWvvzzwt9ECqgEAQ0pjfL5y6qi10OzJGEKvYTDi08oDr2ENvKYhpTEYbeD3otpTSTQkYfJpYDCgWuJQHemo8ZkkEEc9c3189mQMKY1RTDbw+/H63dSz1MevgCG5kb4+muZH87rRvOWBELdqwJuYQV01Db+vLLCfLYmotlRURxqqBhbNEGw7Dp89kTRjGipqYB2syRiS6qPa06hlzqCepT6aPRnFaAVNw5jSGBQF1Z4CgEcrp6G1EUafB9VkQzXH4YlLJdOUhUWx4scf6Ad/OZq3DNWWTLwhnjK/G0N8ForJjs1vQtP8GFMaY9A0MtVUFGsShrhMahlrYfL5QTWiOtIxxGeimu2g+TGkNA7sV1sKRk0hy1wb1Z4KJgeKLdB/Hn8ZDS2N0HzlKMH9ZlSMgX1psqP5vYF97yvHF1+LOv5EvJqHdGM6br8bm2bEmN4CxWhFtaWgxmWSoiZRx1ALvy0ZNa4WFozBfaOgecsD7Xndgf5MrI/mcaE60vBpHmqb6+DHR7lWTropHb8jBTW4rprHFXwzBfrC63NhVI3YVTupxjRURxr4ysFXhjcxgzpeB26rDc1gIsmQhMVH4D2iaYH3v9+PakvFkFCXNK8Zv+bBoJhCb9mK/L7gOmhoHhcGSwKa140xsR4+WwIpxpRAzFUxoPnKMcRnopUVo9qSIHgcqpYE8AfOL17NS6Y5E4tqQQt+3ikmG1pZEWp8Fm5PEUmGwPHg85aixmWCrwzFYAZFQSt3kmQMnNMMyQ1Qgucrv+bF63dHr/2J8ZVjSm8ReK1jUB3pKOrRc53BnoIhsW7EMseiaJoUyBNCCCHEyft00x5GLlgLwLiubbkkwxG9iBDidKJpJ5X559f8qMrpl3uyzLmUVEMKTazN9GmrSn/FolhoZWsTsWy44uJibDYbRmPlF9X8mg9VCQQbhKgeTc+280YHXf0+Pah1dKIPQsee5gdFPfpvILSD4vdBWEAjnObzoBgig0Th0yLWIazdWMvi9+FR/JgUUyAopJrA79MDqMcl/DU0H5rmR1FjBLE0LdBXmgaqQd9OTdMC6+PzQHC9fJoPQ7CP/JofVfOBajraNwCqEc3vRUFBU9Wj2cOaP7gORvz4UTWCfawF+j66XyP6KLAvNbRAe5oWbM+HYjAH2kMNXBxSjYH19Pv19db5fXgVDaNijDwGwl8r1N/hQhecos/3mv9ov/m9EcGuaPq6h/HjR/F5UAwWCO/fYJ/4FCXwf783+NpK1Z9V4cdIKGwWvnxwOzVfeaUBvMC+9AfaCfan5vOgqQZURY25HeEi9pGiBPaNPxCI1o8lTUPzewLrENb3sT5Tq1pXneYP9L/BjOYrx6sqgfdOJW2eqNA66OsS7Jdo+ns47DvF0ecc4/0bo02f5sUQdt4i7DUC26VAaF8BXn85RvUYfXUCnE4nJpMJs7nm2gwnwV0hhBBC1Kj8/PzoSUIIISoRK7i7zrUWFaXawV0hhBBC/O+d6uBu9ULuQgghhBBCCCGEEEIIIf4nJLgrhBBCCCGEEEIIIYQQf0ES3BVCCCFEtZV5fZR5g/XhhBBCCCGEEEL8ISS4K4QQQohqefnXLdhf+AL7C1/w8q9bomcLIYQQQgghhDhFJLgrhBBCiGr5v4Xr8Wsafk3j/xauj54thBBCCCGEEOIUkeCuEEIIIaolznR0pPbwv4UQQgghhBBCnFoS3BVCCCFEtbzc4yyap8TTPCWel3ucFT1bCCGEEEIIIcQpomiapkVPFEIIIYQ4Wfn5+dGThBBCVGKZcymphhSaWJvp09a51qKi0MrWJmLZcMXFxdhsNoxGuWNCCCGE+DNzOp2YTCbMZnP0rBohmbtCCCGEEEIIIYQQQgjxFyTBXSGEEEIIIYQQQgghhPgLkuCuEEIIIaqt1OOj1OOLniyEEEIIIYQQ4hSS
4K4QQgghqmXiis04XvgcxwufM3HF5ujZQgghhBBCCCFOEQnuCiGEEKJaRi/aEPNvIYQQQgghhBCnlgR3hRBCCFEtiVZTzL+FEEIIIYQQQpxaEtwVQgghRLW8fMmZtE5LpHVaIi9fcmb0bCGEEEIIIYQQp4iiaZoWPVEIIYQQ4mTl5+dHTxJCCFGJZc6lpBpSaGJtpk9b51qLikIrW5uIZcMVFxdjs9kwGo3Rs4QQQgjxJ+J0OjGZTJjN5uhZNUIyd4UQQgghhBBCCCGEEOIvSIK7QgghhBBCCCGEEEII8RckwV0hhBBCVFtxuZficm/0ZCGEEEIIIYQQp5AEd4UQQghRLROWbyZh4hckTPyCCcs3R88WQgghhBBCCHGKSHBXCCGEENUyZvH6mH8LIYQQQgghhDi1JLgrhBBCiGpJth4d9TX8byGEEEIIIYQQp5YEd4UQQghRLZN7tKddeiLt0hOZ3KN99GwhhBBCCCGEEKeIommaFj1RCCGEEOJk5efnR08SQghRiWXOpaQaUmhibaZPW+dai4pCK1ubiGXDFRcXY7PZMBqN0bOEEEII8SfidDoxmUyYzafmLkfJ3BVCCCGEEEKIPxlFUaInCSGEEEJUIMFdIYQQQgghhBBCCCGE+AuS4K4QQgghqi3fXU6+uzx6shBCCCGEEEKIU0iCu0IIIYSolvHLfyflpVmkvDSL8ct/j54thBBCCCGEEOIUkeCuEEIIIarliUUbYv4thBBCCCGEEOLUkuCuEEIIIaol1XZ01Nfwv4UQQgghhBBCnFoS3BVCCCFEtbxyaXvOykjirIwkXrm0ffRsIYQQQgghhBCniKJpmhY9UQghhBDiZOXn50dPEkIIUYllzqWkGlJoYm2mT1vnWotBUWlpbR2xbLji4mJsNhtGozF6lhBCCCH+RJxOJyaTCbP51NzlKJm7QgghhBBCCCGEEEII8RckwV0hhBBCCCGEEEIIIYT4C5LgrhBCCCGq7ZCrjEOusujJQgghhBBCCCFOIQnuCiGEEKJanlm6ifSXZ5P+8myeWboperYQQgghhBBCiFNEgrtCCCGEqJaxSzbE/FsIIcSxeTUP5Vp5xDSPVl5hmhBCCCFELIqmaVr0RCGEEEKI49VoylfsKHAC0DDJwfY7LoteRAghRBU8fg8m1RQxrdxfjlmtfFTtvLw8kpKSsNls0bOEEEII8Sdy8OBBrFYr8fHx0bNqhGTuCiGEEKJaXr20PR2yUuiQlcKrl7aPni2EEOIYogO7QJWBXSGEEEKIEMncFUIIIYQQQoi/GMncFUIIIf4aJHNXCCGEEEIIIYQQQgghRAUS3BVCCCGEEEIIIYQQQoi/IAnuCiGEEKLa9jnd7HO6oycLIYQQQgghhDiFJLgrhBBCiGoZ9/NGsiZ/SdbkLxn388bo2UIIIYQQQgghThEJ7gohhBCiWsb9vCnm30IIIYQQQgghTi1F0zQteqIQQgghxPFqPOUrthc4AWiU5GDbHZdFLyKEECfE7/dSXLInevJfTlHxLmy2dJISGmI0WqNnnxC/30t+4VaMhkA7hw8fJi4uDovFEr2oECIGZ+kBHPZalDhzsVnTKCnNIzE+G4Ci4hzi4+pSWLSTxIQGlLoO4Nd82CypFDv3kJTYiJKSXOLj6qFpPrxeF+6yAmzWVAqLc0hJakpJaR5x9izKyovxekspLy8iIT6b4pI9KIpKUmLj6FVC03wUl+zFaknFYDBR4swlIb4exc69xMfVwef34HYfxmpOoahkDyajDbutFgaDWX++230EqzWFktJ9xDvqAFBSsheHI5OCwm0kJzWLelUoLy9GVY34NR9u92H8fh/lnhJs1hTiHFl6P5S49mMxx2NQLRQV78JgsFBeXkxyYhN9HUI0zUd+4TaSExtTVLIXmzUZg2rRt6mwONCms3Q/fr8fuzXQtyZzHFZTEkUlu0lObExhUQ6a5sdkslHuKSExviGqaqC4ZC92azqKasDp3EdCQn39db0+
N+6yfOzWQN94PS4U1UCp6wB2Wy1U1UB+4VaSk5pQXJKL3ZKGohpQVQMFRdtJSmhEccke4uPr4vf5MBjMeDwlqKqZouJdJCU2pKh4DwnxdfH7PbjLClFQsVgScZbuw2JJwmiwoqqGo33n3IeqGLBYEikuCRxDhUU5GI02bNZU/fjTNB8+Xzku9yHi4mpTULiDpISGlLoP4vd7sVtr4fW6MBjMlLoOoml+EuLr4fN5Ij5Xikv24vOVEefIChyzCY0oKNpJUkIDfL5yysry8Wle4h118PnKMRpt+nNLXQdRFSMmkyO47aWoqhGDwUxRyW7i42pz8PAG7NZUHPYMfRs9nlLKPcVYLckUlewhJampvk9K3Ydx2GrhsGdW+vnn93spKNyGwWDBWbofqyUZZ+l+7La0wPGQ1Jji4j3Ex9WlxJlHnCMLl/sIPl85cY5Mip25xDtq4/W5KCsvwmpJpsSZS2J8A/ILAvs7v3AbKUnNcLr2YzEnUOo6iIJKnCOLgqIdmIx2yjxFKJpKQlxdFNVIeXkRZnMCqmokKbFR9GrH5Pd7OVKwBVOwXzXNh18LHEsOeyZHDhdQ7s0lzpFIQeF2TGYHfr8fh60WGn5Sk1tEN3lCJLgrhBBCiGr57479PL5oPQBjLmjNpQ0zohcRQogTtuTXcfj9nujJfympSS0oKc2jwxn3Rs86KavWvYbTdSB6shDiOCTGN6CweCcWazJl7nwS4utTVLwLgPi4uhSX7CExvh6FxbuxmJNAgbKyAhITGlJYtIN4Rx2KnXsBAkEwxYS7vEBvx25Pp6ysEINqBUXB73Pj9ZURH1cPj6cYd1lB1BoFOGwZuMqO4Pd7SIyrT2HJLuIcWZQ48zAYzJgMVtzlRRiNVowGa4V2rJZk3GX5+jYAxDtqUxwMqhYV745YHsBosOL1uVFVIyZTHEajDQUoceYBkBBXj6KS3disKZR7nPh8ZSQmNMTlPoDXW4bf741uEoA4Rx1KnHux2zNwuY+g+T0kOOpS5NxDvKMexc7dWMyJaJpGuacoEAhTDHg8JSTE1aeoZBfxjrq43IdQDWb8vjK8vjIAHPYMSl2H0DQfCXHZFJXk6K9rNNgxGC2UleUDYDCY8Pk82KwpuNxHAEhIqE9R0S4ctnTcZQX4gp8voX6Lc9TGWZqHqhjx+T0YDTa8Ppd+3ISWMxqsKKqK3+fF5y8nIb4eztID+ILrGdoOmyUVr8+Nx+skMb4+hcW7MJkcoICn3Kn3MYDRaMFqTqGkNE+fbjEnoCgq7rICjEYbXq8Lu7UWRqOFopLdGFQzPn+53gdWSwp+XxnlXqd+TJqNcZR7SwE/ZlMcJpMDZ+l+TEYbHq/r6HOtKZSVFaJpPgAMBsvR7Qm2ZbMk4wr2b0JcXYpK9mA2xwEK5eXF+jaG2CyplHmK6Hz2SH1aLBu2fsSR/N/1fWW1puB2HyExPpvC4hz9fWezpeFyHcJqSUZRDIFgePB9YjbFoagqZWVFOGzpOF0H9X4MvQfMpjjKPSXE22uDaqC4ZLf+3kmMy6bEtQ/wB4+bNFzuQ7Rq+k89YH08wj+jTUYHoKFpGp3aP8jBgwcpdP7GvoOLsJgTKC8vwWJOwl1+hHq1LyC7Tvfo5k6IlGUQQgghRLVc2jCDnwdezM8DL5bArhCixiiKEj3pL0lVjdGTTlpNtiXE6cqgmqIn6ULvMUVR9eVC08LffwoKBkNwvhJ6jgFFUVEUVZ+mL69UHnqJ9b4On3b0bwU1xrorigEqaydqPUJUNfCcykRvU0jg/5Wfm4/2nwGFyOXC+za0HYG/Q/17dNvC+z9EVY1RLR5VcfnAOhvUo9nF4dtEWEv66wf78ej06P8fXX+liv0bWo/AcRDcN/prK/r+qkysfop1DEYLz6Q++noq4bsr/FgKF95PRG1PLPo6hh2T0cdmdP9VJtRfRkPgLpTQuhw9RkJ9eLS90N+heYpydD2OLh95XFXVh6pqDG5zYLtD63KiwvdtqM3wfjAa
jmZLE7XPqqvqPSaEEEIIIYQQQgghhPgLqiwkLv5OJLgrhBBCCCGEEEIIIYQQf0ES3BVCCCFEte0tdrG3+GjtLiGEEEIIIYQQp54Ed4UQQghRLWOXbKTuq3Oo++ocxi7ZGD1bCCGEEEIIIcQpIsFdIYQQQlTL00uPBnTD/xZCiFPp8KFi+l05kS4dRx/34/13FgHw1Ogv6NH1KX7flBvd7N/GU6O/qLD9VT2GDZ2Gq7Sc3zfl0qPrUzw1+ovoJv9yFi/8nS4dR/+pt8VVWs6wodMijsdY046lusd0WZmHd95ewNYt+6JnCSGE+JOT4K4QQgghqiXLYY35txBCnEqqQSUjM5FaGUcf6ekJqGpg8Jj4eFvEvFoZiTjiTp9zVEKSvcL2W22BkbmNRkOFeampcRGjqovTy4vj5/Hu2wvx+fzRs4QQQvzJ/aWCuy6Xi+eee45vv/02epZu8+bN9OnTB7PZjKIoKIpCkyZN2LBhQ/SiJ0TTNJ577jm9zWbNmpGTkxO9WEwej4chQ4boz33uuefQNA2AX3/9FYfDgaIoDBo0KPqp1fLzzz9jsVj01/3oo4+iF6nUs88+qz/vZB7PPvtsRHuDBg1CURQcDge//vprxLy/s/fffx9FUbj88stxu9369M8//7xCn1X26NatGyUlJRHthhQVFfHcc8/RpEkTfXmz2UyXLl2YO3cuPp8v+inHJScnh2bNmqEc53GpaRobNmzglltuISUlRV+XOnXq8NBDD3H48OHop1Rwsm2Ejq0TfVS2XbH6NC4ujgEDBrBs2TL9vVuVmmjD5/Mxd+5cLrjgAgwGA4qiYDAYuOCCC6rctyf73g0/zmqijXCV9UefPn2Ouz8Adu/ezYgRI6hTp47eTkpKCrfffjs7duyIXryCk+3TcJqmsWzZMvr06UNcXJy+HnXq1GHEiBHk5eVFPyUmp9PJtGnTaNeunb4uoffu999/f8x1CW1Lly5d9M87g8HAueeeywcffIDLdezat5Vty8mY0usczq+Tyvl1UpnS65zo2UIIcUokJzt4ZeqtzPjyPv3x/qd30+7MbABGjb4qYt6ML+/jqmtOn3PUsHt7Vtj+mwd3BeCSS9tUmDf6qX568Ff89YwafTXfLBhF8xa1o2cdF6+36u8eQggh/rz+MsHdjRs30r59ex566KFKf7SuWrWKc889l7lz5+LxePTpFouFWrVqRSx7ohRFYdiwYfTp0weALVu2MHr06IjXiUXTNCZOnMgbb7wBQP/+/bnvvvtO+gf08dI0jffff5/y8nJ92rRp0yrtO1HzNE3TL0R07doVq/VopkhNBLhXr15NmzZteOihh9i2bZs+3ePxsGjRIvr06UPPnj3Zt+/Ebq3yeDw88MADbNmyJXpWTC6XixEjRtC6dWumTZtGfn6+Pi83N5fnnnuOFi1a8P3330c8L1xNtFETvv/+exo3blyhT51OJx9//DGdOnXi1ltvxel0RjwvXE20UVhYyIABA+jTpw+LFy/G7w9kUPj9fhYvXkyfPn3o06cPhYWF0U/906mqP+bOnUunTp24/vrrq9wWn8/HSy+9RIMGDXjhhRfIzT16u2F+fj5vvvkmzZs355133qk0UFwTfRp6b3Tq1Im5c+dG7MPc3FxeeOEFsrOzq1wPgBUrVtCyZUtuueUW1q5dq69L6L3bvXt3brvttkqPkX379tGzZ0/69OnDokWL9M8hv9/P8uXLueGGG2jfvj0bN1ZeHqGqbTkZ3bNrsejGi1l048V0z67e560QQgghhBBCiOP3lwnuzp49m02bNkVPjvD2229TUFAAwLXXXsvmzZvJzc1l9uzZJCUlRS9+wmw2G5MmTaJRo0YQDJZ+/PHH0YtFWLFiBU8//TQATZs2Zfz48ZhMpujFatyBAwf47rvvAGjQoAEAP/zwA6tXr45a8th69OjB7bfffkKPdu3aRTdz2jl8+DDLly/HaDRy/vnn69O9Xi+///47AGlpadx8880V+i/8
cfnll2M0GsNaDmTWXnvttezevRuAK6+8ksWLF5Obm8vMmTM588wzAfjuu++46aabYmZTVubll1/m008/jZ4ck8fj4Z577mHixIkA1K9fn7fffpvdu3ezevVqBg4cCMChQ4e47rrrWL9+fVQL1W+ja9euFfos1uOf//wn8fHxAKiqypVXXhnRzvLly/nHP/7BoUOHAOjQoQPz5s0jNzeXxYsX68tPmzaNvn37xgwC1kQboaDbZ599BsE2FixYQG5uLvPmzaNDhw4AfP3119x+++0VLjC1a9euwrbHetx6661kZWXpz7vssstwOBw11gYx+uO2225j7dq1Fbbl448/5oEHHqiwLQQvkkyYMIF7770Xv99PQkICzz77LDt37mTz5s3cd999mEwmPB4Pt9xyS8w7O6rbp4RdqJswYQIACQkJTJ48udL1+PLLL6ObgGCfXHrppfp7d+DAgaxevZrdu3fz9ttvU79+fQgeI6NGjaoQJC4sLOTGG2/Uz+/R6/F///d/mEwmNm3aROfOnVm+fHnE8zmObRFCiNNJeZmXGZ8s4+rLJ9Cl42gu6vwkzz41myOHI787vf/OInp0fYolizYz6sGP6dJxNL27P8Pc2asilvu72bnjIA/e9z4XdhpDl46juaH/ZH76fkPE51OoRu/0/yzgw3cXc1HnJ7mw0xjGPTETtyuQaFJW5uHrr9ZwQ//Jeo3fW298nZW/7KjwWQfgdJbx2qRv9P3SpeNoelw4jmefms3BA0XRi1NW5uHDdxfTu/szdOk4mit6jue7b9YRo2l9fd9/ZxFr1+xiyM1v0KXjaC7sNIYH73ufnTsORj8F9L74gIs6P3nM9QE4eKCIZ5+aTY8Lx+nH1oP3fVBp+9UVq+ZuqB+v7Dle78f+fV/kw3cX6/sm1B9ff7UGt9vDbTdNrdDO4UPFTH5xvt6/F3V+ktGPfhZz258a/QX9rpxI7t58Pnx3sf7alb23hBBCVN9fJrh7LE6nUy+9UKdOHSZMmEDTpk3JysqicePGFYJjJ6tBgwY899xzqGqg64YNG8aqVbG/1OXk5HDjjTdSUFCA2WzmrbfeIjs7cJvYqfbDDz+wadMmjEYjw4YNw263U15ezocffhi96DENGzaMqVOnntCjd+/eEW1Mnz4dTdNwOp2cffbZEfP+rlavXs3GjRtp1qwZLVq00KcXFBToGXUXXnghb7zxRoX+C3888MADEVm/mqYxfvx4PbP22WefZebMmXTu3JmsrCz69u3Ljz/+SPfu3QH45ptvmDNnjv78qixfvpyxY8dGT67Ul19+yZtvvgnARRddxKpVq7j55pupW7cuZ5xxBtOnT9dLdBw6dIjXXnutwhf46rYxePDgCn0W/ZgyZQpt2rShuLgYgKeffpprrrlGb+Pw4cPcfffd+sWhxx9/nMWLF9OrVy+ysrLo3LkzM2fO5O2330ZVVb777jt9nWuyDaL64/bbb2fx4sV06dKFrKwsevXqxU8//cTgwYMB+OyzzyoEEXv37l1h+2M9Lr30Uj2r+/bbb4+4o6Am2nC5XDz55JMUFBSgqiqfffYZb7zxBm3atInYllB5jDfffFMPWIZbuXIlY8aMgeAFsl9//ZUHH3yQ7OxsmjZtygsvvMCHH36Iqqr4/X4mTJhQ4Q6F6vYpwPr16yMu1P3222/cddddla7HuHHjKgTvS0pKePDBB/U++eijj5g+fTpnnHEGdevW5eabb2bBggU0bdoUgufN8IsZmqbx/PPP6/3UpUsXfv/994j1GDNmDGvWrKFRo0YUFBTw6KOPVri4c6xtEUKI00VZmZeHR3zIyy98Tf36afS67Azi4q3MmbWSe+96h/z8yLsaysq8PD7qM35bs4tel51Bo8a1aNAoPWKZv5Mff9jILTe8zpbN+7i0VzvObN+AXTmHePThT5j1RcW70D5+/2def+U7OnVuQucLmlGvXipWmxlniZuxo7/gqdFfsH9fIRdf0pqLL2nNrpzD3HPn
dD58b0nEd7t9eQXcdtNUPnh3McnJDq646uzAvomzMGfWSu67+52IoKKzxM1D//6QVyd9g9FooNdlZ9CgYTpjHp3BtDd/0peLtmbVLh789wfk5RXQ67IzaNOuHj8v3sItN7zOTz8cvftF0zTmzVnNoOtf5efFm2nTrh5XXHU29eqlMGfWSm66/lU2rt8b0fYvy7dz0/WvMmfWSurVS+GKq84Otr+ZQde/yk/fV69k4PFwlrgZ9eDHfPDuYtLS47niqrO5+JLWlJS4eXXSNzw+6jPKyjwkJtrp07c9mVlJKIpClwtb0KdvexIT7QBs2byPIbe8yccf/Exiop3LrzyLNu3q8d1/18XcdoJB5XFPzOT1V76lddu6x3xvCSGEqJ6/TXBX0zS8Xi8ADRs2JDExMXqRGtO3b1+GDBkCwUDdAw88UOFHvMfjYfTo0XoA7sknn+SCCy6IWOZUcbvdvPvuuwA0a9aMfv360aVLFwgG+vbv3x/1DHEq/PDDD2iaRvfu3UlLS9On7927V8/a69ChwwlfeCgoKGDZsmUAtGjRQq85Gy4xMZEnn3wSszlQN+3LL7+sEFSNFh6cbNmy5THXq7CwUK8fnZWVxdSpU0lJSYlYRlEU7rjjDs4991wA5syZE3FLfU20cTw+//xzHn30UaikNMrixYtZsWIFAH369OGhhx6qkGGvKAo33HAD//jHPwCYOHFiRJ3XmmjD5XIxdepUvT8efPDBCm3YbDYef/xxmjRpgqZpTJ06tUIw81iWL1/O0KFD0TSNDh068PTTT1d4nWM5VhsbNmzQy2hce+219O3bN+zZATabjdtuuw2j0YimaXpmbYjX6+X555+ntLRUv0DWpEmTiGUInpOvvfZaABYuXBhRY72m+nTNmjV64H748OExL9SFr8dvv/1WIQt2zpw5/PRT4EfmyJEjufbaayu8d7Ozs3niiScg+F6fO3euPi83N1c/t6elpfHaa6+RmZmpzw9p2bKlHhCPdXHneLblZOwqKmVXUWn0ZCGE+NPSNA2b3cy0D/7FS68NYtToq/loxjDOODObHdsPsHJFZD13TdNISXHwxrTbGTX6al6Zeiut29SNWObvxO0q54abzufTWfcyavTVTHr9ZkY/1Q9FUZg3ZzXOkqPjSQCUlLj5vyev4ennr+fp569n4C2B3x8zP/+FH7/bwBlnZvPJzHsYM64/Y8b1571P7qJuvVSmvPItvyzfDsE+/vC9JezZfZh/DevBW+/ewYMjr2DU6Kv5dNa99O5zJjk7D7Fq5U79dWd+/gu/rthO9x5t+HTmPfq6vvrGrezde0RfLtrPizfT+YJm+nMmT7mF0U/1w+/389aUH/QA5NbN+3h54tfYHRZee3Mwk16/mQdHXsFb797ByP+7ilJnGU8/OUtf/uCBIl58/itKnWWMfqqfvg2TXr+ZiZNvwu6wMP6ZOezYfiBqjWrW4kWb+XXFdgbd2lVfhzHj+vPR58Np3aYuebkF5OUWkJmVxD0jenPmWdlYLEYGDe7KPSN6k5mVhKu0nJdf+JoD+wu55fZuvPfJXTz8aF8mvX4zY5+5tsK2h5QUuzl8qIT3Px3G089fX+G9tX7tnojlhRBCVM/fJrgbLjQ4zaliNBoZO3asfivvd999x6uvvqoHz7TgLa/Tpk2DSoJJp9LWrVtZsmQJAB07dqRevXp6reBNmzbxww8/RD1D1LSSkhIWL14MwCWXXBIxb9u2bRQXF6MoykllMbvdbg4cCHwZzMjIiLgNPlyDBg30wM/evXurrKnp8Xh45JFHWLFiBU2bNuWJJ57QA8OV2bBhg561PmTIkJhBN4KB5u7du6OqKnFxcRw5cvRLdk20cSw5OTk88sgj+P1+srKyGDduXIXgXihYTjCz02azRcwPMZlMDBgwAIJ9unbtWn1eTbSxc+dOPUDcuXNnvaRKtHr16unZ8cuWLYuoZXsshYWFjBw5Ur+jYMKECaSmpkYvVqXjacPtdtOjRw+S
k5Pp06dPpRcLGjduTO3agYE/ossh7N69m4ULFwJwzTXXcN5550XMDzEajfo5LjExMaJmc0316Z49R3+EVBYMNRqNekkUl8sVUe/a6/Uya9YsCN5dcuutt1b6mdCxY0cyMzNJTk6OuBi3Y8cO9u4NZMf06dOHVq1ahT0rUrdu3ahTpw4A8+fPj7i4czzbcqLGLF5P9mtzyX5tLmMWVyy/IoQQf1bX33g+DRsdrRXuiLPSo1dbAPbtC1wIC3fe+c3IzKp+ube/guwGaVxzbUeMRoM+7ZwOjchukMbhQ8W43ZGf29kN0jinQ6B8XciRIyXMm7Mak8nA0GGXkJIap8/LzEri3vt7o2kw98tV+Hx+igpdbNu6n/rZafTo1Tbis9JoNJDdIJAwEcrcLS528eN3G4iLt3LTrV0jBoRr064e/Qd00v8fLT09gdvuuEh/jqIodLu4FRdd0joiADn3y9WUFLu5cdAFtGlXT3++oihc2rudvvyKZYHvDksWbSZn5yEu63MW3S5uFbEN53RsRP8BnSgsKOWb+Ue/A54KoT5KTgkMmhqSmGjn9f/cxjsf3UmDhlVnnv+2ZhdrVuXQqk1drru+U8Sx0PWillz1jw4xL4QA9Ol7FnXrHU3acMRZOe+CwB1COTtPTWkKIYQ4Xf3pg7uhUdsffvhhfdqVV16JoihkZ2czf/58HA4H8fHxekbUTz/9RHx8vL7M8Y5efiJSU1OZPHmyXst39OjRLFq0CIBFixbx2GOPwR9cZzfks88+o6CgAEVRuO6661AUhe7du+sZkf+LgdVC2aUOh6PSwcRcLhfvvPMObdq0QQmO2p6cnMxdd93Fnj17yMvLIzs7G0VR9Nv0o2nB0d8HDBhASkqK3k5GRgZDhw7Va93GElrHbt26UVJSQl5eHiNGjCAjI0Nvp0GDBjz++OMUFVWsLxVu9+7drF27lszMTNq0aRMxL7T9GRkZev3mE2EymYiLC3wxdjqdFQJiIQUFBfrt2GlpaRGlHaLNnj2bt956C1VVeeWVV2jYsGH0IhX8/PPPlJeXYzQaKwSwoz311FP4fD7WrVtH27aBH0zUUBtV8Xq9PPPMM3oG/f/93//FDCCHgl12u10PiFUmOzsbuz1wm1p4QLcm2ti6daseuO7Ro0elAVGCwTuC+3nNmjXRsysVXv7gX//610ndUXA8bZx//vnMnDmTI0eOcMMNN0TP1m3btk3PxI4+T27atCkimFlVf9xwww1omkZubm7EsVRTfRpet72q939ofY1GI8nJyfr0UA1ugE6dOlGv3tEfh9EaNWpEXl4eR44c4YUXXtCn5+fn63eotGjRIuKHWrSkpCT9WF+3bp2eqRuaF1LVtpyIZ5cePbeG/y2EEH92deoePVeHpKUHavQ7nWXRs2j4Ny7DEC0tPQF7WLAUwGw2kpTswOPxUVYW+EwKyaqdhM0eufyeXUfYs/sILVvXoVHjigNuNmpci7S0eDau30thYSmJSXYmT7mF9z+9m+RkBwcPFrF82TY+eHcx9971Dm++Hpmksn9fIbtyDtOocS0ysyreudnujPqVfl6e2T67QqDeYFDpfEEzADas34PbVc7OHQexWEy0P6fi92ODQeWi7q0BWPVLIJt404bA95pO5zfFYKj4c/u885tisZjYsG6vXvf2VGjVui6qqvDi81/x2COfsHjh7xWyrY9lzeocNE3jgi7NccRF/pZQFEXvk1DmdbjGTTKiJx0zmCyEEOLk6J8269evJz09XQ9iVfZo2rSpfkv56a5jx476rd7l5eWMGDGCNWvWMGTIEMrLy//wOrsEa5KGBsNq2bIl55xzDgSDzJdeeilUY2C1U2njxo20b9+eQYMGRdSYLCgo4NVXX6V58+Yx62CGczqd3HrrrXTq1ImPP/44InvvwIEDTJkyhRYtWjB27NhKA6Ihc+bMoUmTJrzwwgt6lizBLNAnnniCxo0b88svv0Q8J9yCBQs4cuQI7du317MSCWYzhvo+
MzOT2bNn06lTJ+LiAlfUwwPIhw8fDmvxqNTUVC666CIIBor/+9//Ri+Cpml88MEHelDrkksuqTSotXXrVoYNG4bf72fs2LHHDLISbD8UpK5fv/5J1emsiTaO5eeff+Y///kPBOsb33jjjdGLVMvOnTsjMiJPRngb4eUEQgNrVaZevXp6dnBOTk707Ji2bt2qD6LVpEkTRowYUekPnsrURBshLpeLN998E6/Xi6Io9OvXL2J+KKs7Pj6+wkWS41VTfXr++efrQdEpU6ZUKMVDsG9CpSWaN29Os2aBH4cAu3bt0i80du7cudL346mQl5cXkbl/PNtyourGBy5WRP8thBB/ZlarieSU2HdAARzcX/ECWCjwezpIT4+PyIQNV+oso7g4MlkkMdGO2Rz5+VZc7MLn87Mvr5BJE+fz3LgvIx5vvP495R4vRw6X6JmmblegFED3C8ZyzeUvMGLYu7w26RvWr9tDYlLkZ4zP58fv95OcEofFUjGZJj7BisUS+zO3YaNaMb/DOBwWCO7/sjIvhw8Vk5BgIy0t9r5PTnZgMKgUl7hxuco5fLgYi8VErYyE6EUh2E+OOAvOEjcery96do05s302dw2/FFVV+PG7DTw84kN6XfwMN/SfzEfvL4l58SJa6D2wfNm2CvvuuXFfMv+rNSiKwp49R3CVHg1UH+u9JYQQombpwd1WrVoxdOjQyLkxjBgxosqMo5p21113kZubqwdRCQ4yk5uby4oVK7jwwgvZvn07W7ZsoVOnwG03nTp1YsuWLfoytWpVvEpcU4YPH07//v0BWLFiBRdffDGbNm2CP7jObsgvv/yiD9bVv39/vdar0Wjk+uuvh2Ag+v333692UKqm5OTk0LdvX73fevTooY9kv2DBAnr06EFpaSlDhw5l165d0U+HYBmE/v3766Uw6tevz+uvv87OnTvZuXMnr7/+uh7Yeeyxx5g4cWKl27906VIGDhyIx+Nh2LBhrF27ltzcXGbOnKnfcn3o0CHuv//+CgMVEcwW/eabbyCYZR6eMVtYWKhv5+rVq7nvvvtYtmxZROAlFEBu1KgR8+bN06eHKIrCAw88QNOmTdE0jYEDBzJ58mQ9QLN//36GDx+uD4zWvXt3fd9HKywsZOjQoeTl5dG9e3fuvPPOmF9yozmdTj1DsV69ejgcDoqKinjuuedo0qSJHqjOyMhgxIgRMbPna6KNqrhcLp577jnKy8tRVZVRo0bpGc/RQjW6y8vLj5nVfuDAAUpLA3VFw8td1EQboQxju91+zPOWqqr6vqoqIz1E0zQmT56s9+PJnMtrog2C79dZs2bRuXNnpk+fDsBtt92mDwIYEtqu5ORkatWqpWf3t2vXTi+/ExcXxy233FJpH9RUn7Zu3ZpHHnkEgnV9r7zyStauXYvP58Pj8TB//nx69epFXl4eqqoyZsyYiNfbt2+fflyE3rvLli2jT58++sUdg8HABRdcwNy5c/H5Kv7Ys9ls+vodK+PW6XTqF2Lz8/Mjyjsca1tOxpSe7bmwfjoX1k9nSs/20bOFEEKc5g7sL2TOrJV8OfPXiMe8OaspLDhar72szMPjoz7j04+Wkt0gnX8/eDnTP7yTed89zDc/jaT/gMAYDDVBNRz7O++JMJsMqOrxt2m2GGNm9tYURVG49p/n8dW3D/HYmGtof3ZDVFVhV84hXnnpv9x201T25VUsPRLL6pU7K+y7L2f+ysKfNlX6m0oIIcQfR/80URSlypqXBAd/CgUy/yhxcXFkZWVFBGWSk5PJysqiVq1aWK1WMjIyyMzMxGIJXGW1WCxkZmbqyxgMR2sD1TSTycT48eP1rMNQpuQfXWeXYFBx+vTpaJqG2WymZ8+eEfPPPfdcWrRoAcDMmTPZufPoQARVCZXBON7HoEGDopuolBasTxwKwIwbN4558+bpI9l36dKFefPmMW7cuCq/OLz33nt6IPSiiy5i1apV3HHHHWRnZ5Odnc0dd9zBqlWr9IzXxx57TC+j
Ea2srIy4uDh+/PFHXn75Zdq0aUNWVhZ9+/blxx9/1AenW7x4ccwM6P379/PLL79gt9v1zOmQPXv2RNThrF+/Pq+88gqbN28mNzeXefPm0aNHDwgGb/r06cOMGTPCWgjIzs7m+++/p3///noQOikpMMJtZmYmkydPRlVV7rnnHmbNmhVzgEFN05gyZQrfffcdSUlJjB8/PuZysbjdbg4eDNTKSktL47fffqNNmzY89NBDEbVKDxw4wAsvvECTJk0iBoaihtqoyk8//aQvf9lll1V5oaVz584QfA999tlnlR5rXq835v6ghto4kcBaRkZGxG3/x7J+/Xo9kNq6dWuuueaa6EWOqbptbN++naysLOLj47nqqqtYvXo1NpuN559/nldeeSWiLIPb7ebQoUMA1KpVi3379tG1a1cGDRrE2rVr8fv9EAxiTps2jVatWjF58uQK/V5TfaooCiNGjOCDDz4gPT2dBQsW0K5dO4xGI2azmV69erFt2zZatGjB999/rw+aFxIqPWE0GrFYLIwYMYJOnToxd+5cPbjv9/tZvHgxffr0YcCAARUyasPrE8+aNSviroJoixcvjhisL9yxtuVkXJRdix+v78aP13fjouyqg+hCCCFOH/HxNgwGlR4927Jg2eMsXD465uObBaNo3qI2v2/MY9nPW2naPItJU27m6n4daNS4FnHxscuLGQwqqqqSf6SEsrKKn/l+v0YlX8nYuT123ddDB4sBSM9IwGIxkpoWT1GRi0OHAtOj7dtXgM/nx2IxYbGYSE2Np6zMw4EYmd8E2y8sKMVkMqL+Ab8VHXFWLu3djpdeG8QPSx7jtTcH07RZJnt2H+arORV/y4RLD2Yfj332ugr7LPwx6fWbK5TkEEII8ceJuFRYr149RowYET5JpwTr3kYPmiMCgbYHHnhA/7/BYOCOO+6oUD/yVNuyZYt+i/5FF12kZ5mGZGRk6MGYvXv3nlCg7FTZuXOnfhvzhRdeyLBhwyoE4w0GA/fee68+YFK0Q4cOMWnSJACysrKYOnWqXl84XEpKCuPHj8dut1NeXs5//vOfCoGgkHvuuUcP1oVLTEzU64d6vd6I0g8ha9euZdeuXbRq1apCTd3S0lIyMzMxm83cd999bNq0iTvvvJOmTZuSlZVFr169mD9/PpMmTUJVVfx+P48//njMIE5RURHFxbG/ZBLsNyBmBiBRtaEnTZrEWWedFb1IpTwej561/Msvv3D55Zeze/durrzyShYvXkxubi6LFy9m4MCBENzuK6+8ktmzZ9doG5VxuVy88soraJqGqqoMHz680gHOCL5fQhc+Jk2axCeffFLh2NA0jRkzZvD2229HTA+piTZOhMFgqPBeqYymabz55pt63dW77777mFms0WqijcOHD1cIWLpcLl577TVmz54d0V9er1cPeu7du5devXrxyy+/0KFDB+bNm0dubi5r165l2LBhmEwm/H4/w4YNi6hRe6KO1afl5eV6wLkyBoMBn89XYd+H17y99957mThxIgkJCUyePDnmHQafffYZ/fv3j7g7oEGDBlx11VUQDLSPHDkyZpb41q1befjhhyusQ7jj2RYhhBCiurLqJJFeK4HNv+eRn1/54L4hoTIOjRvXIjExsgRDWZmHlcG6tiFZWUk0aJTOxvV72b6t4vflDev2xgz6AmzcsJcjRyLvwvP5/KxauRNFUWh/dkOsNjMNGqYHX7viRVOfz8+SRZsBaNEqcAE29O/SxVvw+QIXo8P9smI7Pp+fBg3TKy17UV0ej49nxs6ib+/n+W310TsfVVWlTbt6DBp8IQB5eyv+lgnXqnVdAFb+sqPK7xVCCCH+tyrcB9K/f386dOgQPZnLL79cH0lcRMrJyWH8+PH6/30+H2PGjKkQxDjVvvvuOz1z+Oqrr44ZzOrTp4+emfXee+8d1zr26NGD22+//bgfXbt2jW6iUqtWrdJvzb/++usrvW3eZrMxYMCA6MkQHHRp8+bAl6qePXvS
BAQEYMuWLWbLFJaIiAicPXsWAODn55drrmNL2rZti7t372LRokVo3bp1jqkq8mvv3r0oV64cPv/8c1y8+GRQAYPBgJMnT+Kdd95BnTp1cPXqk9xglsTHx6N3797o0KEDjhw5IrWj1WqxY8cONG/eHB999FGOPS7z4uDBg3j06NFTA29P4+joiFGjRuHhw4cYNWpUnoPsAGBjYyP1hI2KisKqVauyBa1hHMTMlL4gKCgoz8G6W7duSYNoBQQEYMSIEdm+nMGYysCkdOnsX1az8vX1lbYx63lvyk/u6OhYKMHynKSmpmLevHnQ6XQQBAHdu3eXV3mqvLZRkGMbEREh9cr28vLKNSCv0WhQpkzmD44HDx5YDPKLoojp06djz549KFu2LGbMmJHn9cnPOWNJ9/XH8MGOM/hgxxl0X39MXkxERPTCeXg4mvWEzSolOR2JieZ/W51d7GFlZX4jOjExFXq9AVGRjzHj5x2Y+v0ms2nurL3I0OrwKDZJ6mmalpqZCqBVk4no2v4njAheij9n7MLlS/egcTZ/0kevN8BgMMDF1QHW1tlvLDs62cDa2vLN8TJlS1j8Dmdvbw0AiHmQgPR0HWIfJsLJyRbu7pkDasm5uNhDqVQgMSkNqakZiI1NhLW1GiVKOsmrAsaes/YO1khOSoNWlku2MNWq44ePh74BhULA/j1X8OWIFWj72mS802MmVi4/muNgX1nFPMg8JidP3M527KZ+vwk7tl6AIAi4d++R2YBfNjZquLiaB/otSUxMRUpyOlLTMjD3z73Z2p/5y05EP0hASko6HkSZp3U0DXBWWLr1DMSo0R2fpAFZ/yneaFcDF86H4edp25Ai218Z6To8jk+BIAjYuulctnWf+v0mKfh760YU0lIzEBkZDxsbtcV1FwQBZcpmvwGSVUaG7pnPr7weE6L/Cim4W6VKFQwZMsS81IIRI0bA1/fJ4xtF7eOPP0ZERIQURIVxcK6IiAicOnUKzZs3R0hICG7evImGDTMfU2nYsCFu3rwp1SlRIvcPlYIYOnQoevToARgfz3/ttdekx/yfZ55dk9OnT0tBsR49ekg5QlUqFd5++23AGIhevny5xQDRsxJFERs2bMDp06cBY95jUyCiqGzdulVKOVCpUiXUqVMHMAZ5Q0JCsH37dimw0b17d0RERCAiIgKtW7cGjDc0TL3fAgIC4ODw5LGqwjJs2DBs374dnp6eWLhwodQTdODAgYCx52inTp2k7ShMISEhUn7lgICAAgVuCsumTZvQtm1bKaDVp08fHDlyBBEREThy5Aj69OkDGAPvnTt3NgsMyg0bNgxr165FrVq1sHr1aoSHh+PGjRsYPny41Mty9uzZZoN6/frrr4iIiJACdba2tti+fTsiIiKwbt06qYcsjPlhTcelU6dOuQbenmbmzJmYMmUKnJwsf2l5mj59+qBVq1YAgNGjR+Orr76Sju3jx48xefJkvP/++4DxRk5e08CIooiZM2dKPb5z+3w35Q+3s7N76meqQqGQ3v/69czRcbO+dnFxQYkSJZCamoolS5aY9eB2cHDAe++9Z7ZcXiUlJWHDhg1o3LixNKDfwIEDpX2XF/ltoyDHNjo6Wnq6oHz5zEfIciIYe7cjl7Qbf//9N77//nsoFApMnTo1x9QZllStWhVfffWVfHaerbv+JK1D1tdERESvougHj7F5w1lsWn/GbNq2+Twex2f+bYfxMf1xY9Zizcrj8PP3wGej2mPxio+wbc+X2HVgNHr0zkzzVhgUyqd/98sPK7USCkXe27SyVlnseVlYBEFAz/81wtbdX2DshK6oU7cMFAoBd8Me4vdfd1oc7Csn58+GZjt2m9afwaED1wrlN3JSYhq2bTmfrf1N68/keR0Lm42tFYZ+1haVq5TCmVMh+P23XdkGdoPx98HRwzeyrfem9WekVB2Z9TJvRjwvRX1+ERV30tUhCAIGDx5scXAbk/r160uBzOfFwcEBXl5eZoE3FxcXeHl5
SY8elyxZEp6enrC2zrwraW1tDU9PT6lOYfU4tEStVmPatGnSD3PTY8vPO88ujMGoxYsXQxRFWFlZSYONmTRo0ACVKlUCAKxfv75QBjxLS0vDpUuXMHjwYGnAHxjzUD7r4FO5Mb3foEGD0L175gABCoUCEydORMmSmY9lqNVqlCxZEu7u7tL+t7Ozg5eXl1nvuKyDoD0tsPKsHj16hKZNm+LcuXPo378/fHx8UK1aNcyZMwcrV66EQqGAwWDA2LFjsw3KV1AJCQnSl5NSpUqZBS5fhOjoaHz11VfQarVQKBRYuXIlFi9ejMaNG8PLy0sKppn2y82bNzFy5Mgce94+evQIAwYMwNGjR9GjRw/4+PigfPny+Omnn7B06VIIggBRFPHXX39JQXxnZ2d4eXlJ+VAFQYC7uzu8vLzg5uZmdr0+ePAAp0+fhp2dndQD/kXRaDT4+++/ERwcDIVCgSlTpsDT0xOCIMDZ2Vnarz169MDBgwfz/LTA5cuXpQBm1apVc805ndNxsKRkyZJwcXExm5c1d3CJEiUQFRWFZs2aoV+/fmY9uJOTk7Fo0SJUqVIFM2fOzNMX7JCQEHh5ecHR0RFdunTB+fPnYWtri+nTp+P3339/akoFFFIb+ZVTagVL7O3tpQHzLAkLC8NXX30Fg8GAwYMHWxwkLjeCIGDEiBH466+/5EV50qbMk8fisr4mIiJ6lTg62kKpVKB1m+o4eGIcDp0cb3HadXAMKlbyxvWrkThx7BbKV/TCjNn98Vb3+ihbrgQcHC13GlAqFVAoFIh7lCQNWpWVwSAip69GoSGWc4w+jMkcZNejpBOsrVVwc3dEQkIqHj60PPhuVFQ89HoDrK3VsLZWw83NEenpWkQbe73KPYxJxOP4FKjVKiiew29fewcbvNGuBn79sx/2HR2LP+cNQPkKnrgXHoutm3N/OtXD2Dt04pRe2Y5Z1mnGrP6wtbPcyzs3jo62sLO3Ro1apbFj/1fZ2s06BTXN31OjhUGjscPHw96AWq3EpvVnsHfXZanMyloFjbMdrK3VmLNoULb1zTqNGf8WbO2s4OmpQVqaVjrHshJFMdsgaHJWVqqX7vwiKq7Mbn34+vpixIgRWWdJBGPeW/kgO5T5yPvIkSOl/yuVSnzwwQdFEgzIzc2bN7Fz507AOCBWrVpPRr6EMeBiCt7cv38/z+kATDmOLU22traoXr065s2bJ9UfNGgQBgwYYNZGXmXNjWtpkr+fWq3GwoULcw1K5SRrz0Afn8IfjRTGoPLPP/+cLUWFIAjo2bMnPvzwQ8A4AJJpsKrCYuppiSLcvvz4559/cPly5heI0aNHo2fPnhBkf6AFQUC3bt2kGwVbtmzJcaC8UqVKYcyYMRYfO2/VqhUqV85M6v/w4UOkpeUvHxcAXLx4EXfv3kWVKlVQtmxZefFzl5CQgORk85F7szL1ls1rwFAURcybNw/x8Zm9Bz755JOn9sjNK6VSme2mmk6nk9b//v37aNu2LU6fPo369etj27ZtiIiIwMWLFxEcHAy1Wg2DwYDg4GD89NNPZu1YEhsbm+3mSGpqKv78809s3LgxTwHiwmijKAmCkGOOYq1Wi5EjR+LmzZuoX78+Jk6cmGPd3GRkZEgB+Pxa37UJJjWrhknNqmF91+f7xAoREdHz4lXKGR4lnHDjeiTi4nL+XmZiSuNQrlwJaDTmKRjS07U4a8xZauLl5Qz/sh64evk+Qm5nD4xduXTfYtAXAK5euY9Hj8yf7DHlURUEAXXqloGNrRX8y3gY3zv7oGF6vQFHD2emratUJTMlmenf40duWuypefpUCPR6A/zLeOSY9qKgtFo9Jk/cgM7tpuPf83el+QqFAtVq+KLfgOYAgMj7TwbhtaRK1czfRGdP3ymS73YuLvYo7eeOkNvRiIos3I47haVGzdLo0i1zYPiF8/Yj+kHmeiqVClSu4o30dC2uXMocI+NpTAPynTx+O9v+jItLxo3rlseXyeplOL+IXgXZ
+rX36NED9etnXuxZtW/fXsr7SObCwsIwbdo06f96vR4TJkzIFigoanv27JF6Dr/11lsWg14dOnSAlVXmh+KyZcsKdR1tbW0xdepU/Prrr0Ua2FYoFKhevTqmT5+Ohw8fom/fvtmChPllesS9sLVr1w41a9aUzwaMwZrevXtDpVJBFEXs22c+oEJhKqrty6usKQ7s7OzQpUuXHI+ZSqVCp06dAGPaCtNAiXI1a9aEl9eT0V2zsrGxgYdHZjL9O3fuIDEx+93kp9m6dStEUUSjRo2y9UJ93nbv3o3q1atjwYIFMBgMZuksTpw4gcGDBwMAVq9ejYoVK+bpxk1oaCjWrl0LGNOavPWW5ZF2i0JkZCSio6Mxbtw4HDlyBG3btoWXlxeqVauG3377Dfv375fSiPz444+4deuWvAkz3t7e2Lp1a7YA8e3bt9G9e3dMmzYt2xdOucJo40UQRRE///wz1qxZA2dnZ8ycOfOZbsJGRUWhTZs2GDp0qLwoT2xUCoxuVBmjG1WGjSrbVwsiIqJXgru7I9q0q4Gw0IeY+8des0CrKIrYs+sSWjT6FkM/XISEhFSpp++/F+6aPY5vMBiwesVxnDxu/h3H3sEG7TrUglarx9w/9iIh4Uke4KuX72Pe7L1m9bMKC32I9etOS4/ai6KI/XuvYN/uy6hZ2w9Vq2UGNtt3rAUHRxssW3wYl/59MpaOKIrYue1f7Nt9GX7+7tJAbI2bVICfvzu2bj6H/XuvmH0fOn0yBGtWHoeDow3adzTvWFSY1GolylfwwqPYJGzecNZsv+t0ehw5lNlpxxRsNElP1yEx4Uknj+o1fVG5SimsX3cKe3dfNtuW9HQtpkzciOYNJ2Dh3P3S/PywtbNCx851kJSYhl+nbzM7fgBw6d9wtGs1Ge/2nIl74Zm/2583QRDQ8+2G8C7lgnvhj7By+TFpPzRvWQVu7o6YN3uv2bkBAAkJqfj0w8Vo2fg77N5xETAOdubj64qtm87hVJZB6HQ6PebN2oew0Kd3GngZzi+iV0G2X2Bubm748ssvzQIvdnZ2+Prrry0GC//rUlNT8cknn+DmzZsQBEF69P3QoUP48ssv89yLrqAeP36MZcuWAcYejabB0+Rq1aqFli1bAsbeoseOPX3gm9atW2PQoEEWp5EjR2LFihW4cuUKEhMTMXLkyAKdJ1lz45qmrIEWGPP5/vLLL/jss8+eKc+lSZUqVaTX+XnkPD/q16+faw86f39/KZ3E9evXC/V8eR7bl1fx8fFSLmhnZ2eoVCpERkbmOFlbW0vn0YULF2StZXJ3dy9QHtzcxMXF4dixYxAEQQo0vyjh4eH48MMPER8fD2dnZ+zatQtLliyR0lkEBgZi9uzZOHToEJydnZGSkoJBgwY9NSB69OhR3L+feVe+V69e0nlYGFJTU5GRkftoyc2bN8fnn39u8UZQ48aN8d133wHGQPC6devkVcyUKlUKLVq0yDFAPHbsWBw/fly+mJnCaKMoZe35nNWpU6fwww8/AAC++uorizdnn0an0+Hrr7/GoUOH5EVERESUhSAIePvdxmjQKACbN55Flzd/xLjRazB54ga82/N3jB+zNvNJtJ4N4ORki/IVPFGnXhlE3I/Duz1/x7jRazBp/D/o3O5HLJizH292rA2lUmGWUqFDp9ro2KUuzp65g67tf8K40Wvw1ecrMGTAPOh0hhw7SNjYqLF4/gG82/N3TJ64AQP6zMb4MWvh6GSLjz99Q0ozEFDBE0OHt0VKcjo+HDgfwUMWYer3mzCgz2x8/+162NlbY9jnb8KjRObvLI8SThj2+Zuws7fG+DFrMaDPbEz9fhOChyzC8E+WICU5HUOHt0VAhacPpp2WpsXnny5Ht44/W5w+Hrwgxx7Rbd+sgbr1y2LblvPo3ukXTBr/j7Tft20+j7r1y+L1N54M2Otf1iPzJvi0rfhl+jbcv/cIGo0dPvuiPRydbDF+zFppX40bvQZd3vwRmzeehZ+/B9p1
ePZA4mutq6JL13o4e+YOOredjq9GrsTU7zdhUL85+HDgfCQlpqFD5zoo5ZO3ziOTxv+DpoHjsXzJYbP5y5ccRtPA8Zg0/h+z+Xnh6eWMAYMzYwLr153Cvxcye0P7+Lpi+Mg3kZaqxYcD52NQvzmY+v0mfDVyJTq3nY6zZ+6gbv0yaBSUmdKwREkNhnz8OjIydBgRvBTBQxZJx2TLxrOwsVFDEASLgwOaFOb5RfRfli24C2Nvw/bt20v/HzRoEAIDA83qUObdzRkzZmDz5s2AMTC5e/duKRgwZ84cbNiwQbZU0Th16hROnjwJGB95Llu2LAQLaQ3s7OywY8cOwLj+ixcvfmpAMTg4GHPmzLE4TZ06Fb1790blypWzPYb9LLLmxjVN8kDL7du30apVK/z0008F6kmXNVVB1hQGhSlrgNWSrI+vP2v6gJy4uLhIgeX79+9bDAw9L1qtVhoAKiIiArVq1YK3t3eOU9u2bZGamnmn+0Wse0hICK5cuYLSpUtL6R1Mzpw5A3t7+2zXlmmyt7fHmTNnzJYpiK1bt0qB2i+//BKvv/66vApgDIjOmDEDMAZETTd7LElLS5Nyq1rKz22JpSBsTuLi4pCQYJ43S6VSmeV9fvvtt3MdxLBVq1ZwdXUFAPz777/5vtazBogzMjKwYsUKeZWnKow2cpPbjR+5rDmLTWJjY/HJJ58gPj4eHTp0QHBwcI4/+HJz8+ZN/PNP5g+D5s0zH2skIiIiy+wdbDBpai8M+7wdHBxssHf3ZWzZeA7R0Qlo16EWFq/4EM1bZn5/tLWzwneTe6JH74bQ6w3Yu/sydu+8hEZB5bF4xYfo934zaJztEHonBomJmd99VSolPhv1JkZ/0wWOTpntHz96C292rI0Jk7rD2try94fmr1XBzzP7wspKhS0bz+FOSAw6dK6DhcuGoFLlzEffYQxQZ67nR2gUVAGX/g3HpvVnEB7+CB0618GSFR+hXqB5SrJ6gWWxZMVH6NC5DsLDH2HT+jO49G84GgVVwOIVH6Fdh1p5/g4SH5eM6AePLU4Poh7DYOHRfGTZ7//rEwQBwPatF7Bl4zno9QZ8FNwak6f3hr3Dk44f7drXQstWVXE37CHWrT6BWzczn2SsVNkbC5cNQYfOdRAdnYAtG89h7+7LcHCwwYfBrfHn/AHw9Hr2gahVKiWGj3oTE77vAd/Sbjh84Bo2rT+DWzcfoFFQBcxeMBC9/tcoz/urqDRtXgn1G5SDVqvH4vkHkZqS2TGjecvKWLh8CBoFVcCtmw+waf0ZHD5wDb6l3TBm/FuYNLWX2X5u/loVzF44CDVr+eH82VBs2XgOzs52mDn7PbR4rQqsrVVwdMq9Q05hnl9E/1WCmMMv5jNnzqBZs2bQaDQ4ePBgrgOtPQ9TpkzBl19+CQDYuHEjOnbsaFaelJSEDh064MCBA2jevDk2b96ca+CgMBw6dAivv/46MjIy4OXlJe2nWbNmSblUy5Ytiz179uQ4arlpP6ekpKBv377S4Eb5IYoigoOD8fvvv8uLnsrZ2RmHDh1CtWpP7nICQL9+/bBkyRIgh/2dX6b27OzscPDgQdStW1cqy88+WLduHXr27CkNpLZ69Wp069ZNXg3IQ7tZy5s0aYKtW7fC0dHRrE5uQkJC0KVLF9SpUwc9e/ZE69atoVarzdp92r6LjIxEw4YNcffuXbPzNut8S+uelWnfli5dGsePH5dSFdy/fx9BQUEICwtDqVKlcPz48Xzl3o2Li0OXLl3g4eGBrl27onPnzrC3t8913XLa51mXya+s+yWn9uWyfh7I9wuecj4CwMyZMxEcHIxu3bph5cqVZkG4rOtgSU5tymVtZ/Lkyfjiiy/kVYAs62pra4sDBw7k2jPz3r17aNiwIe7fv49WrVphw4YNFgfSu3r1Kpo0aYJHjx7l+fMy62fw1q1bc03Tc/jwYbRs2RI6nQ6TJk3C6NGjgXx+ruR0beTHhQsX0LRpUyQmJua6P3LzLG3k
9dhmrTdgwACz3OVycXFxeOONN3D69Gnp82r//v3P3LM863m6adMmqZ2sxyu//jyXOYLyh7XLIS4u95x3RPTqun5nPnT653tTlui/7Pq1CHwyeCFavFYFY8Y/vzRbRLlJTcnAqM/+wt2wh/hz3gB4l8pbT2Wi4sBVUx1eHpm93/MqOTkZarVaSpNa2Cz23AWAOnXqYNCgQfj0009feGD3ZRQWFoYBAwYgIyMDCoUCM2bMkPbTgAED0KNHD8AYABw8eLDUa7EohIaGYv369QCA0qVLY+HChVizZk2ukyk1Q3x8vJR3szjo2rUrgoODAWOuqoEDB+LcuXPyankSEBCA6tWrAwDOnj2LGzcyBw/Iq5MnT+LixYtYvHgxfv/9d+j1mfmtsoqIiJDPMqPX66XlSpUq9dSgkVxOj2rDOIBeo0aNAGOg9+zZs/Iqubp27RqOHz+OdevWYdKkSVJP2mehUCikAGnz5s2RmJgIURTzNO3fvz/fQb2CSEtLw7Zt2wBjSpL89K4sSq6urjnmGDZxdHSEr29mrjGdTpdjb9ezZ89K+blbtGiRp/1bo0YN6fXTgvRhYWHSEwGVKlWS5ueUf/ppVCrVM92td3Nzk1K35LY/clMYbeTEx8dHGmwxMjIy1577cXFxuHfvHmD8nM/LMXsWWY9XfnT95yg+2nkWH+08i67/HJUXExEREdEr5sih62jZ+DssXXjI7DuyKIo4euQG/j0fhkqVveHmXjTfW4noiRyDu4Ig4LvvvnvmwVVeZVlHJgeAwYMHo3PnzlK5Wq3GtGnTUL58Zi6aXbt24Y8//ijUoEBWWXNnduzYEf369UP37t1znUaNGiUFS1atWvXCB9zKK0EQMHbsWKn3Ynx8PL755ptnCjxqNBq8++67gHHgrqVLl+b5GCUlJWHWrFnS//v06WMx/+vly5fls8xcu3ZNCgDXqFHDYgArOTk5x9QZqampOR47lUqFfv36SW3OnTs3z/tJp9Nh/vz5Ut7UHj16wN3dXV4tzzQajRQ0un79OqKjs4/++7KIiIjA2bNn4erqimbNmsmLUbduXSQnJ2cLQpum5OTkp/bafRaPHj1CZGTuI84mJiYiPDxz8IPcAqInTpwAjHVySvMg5+/vL6VJOH78eI7XiiiKUi5vR0dHlCtXTioLDAyUguVPy18bEREhpSHw9/eXbnxMnjwZpUuXhpeX11PTX4SGhkrXR9abJ4XRRmHIel1cvnwZsbGx8iqSGzduSOtRt25dCIKA1q1bZ8tRbmnq3r07YBz0cvv27YiIiEBISIhZwN7k2rVr8ll58s+NJ6MqZ31NRERERK+mChW9UMrHBXP+3CPlL86af9rdwwkDh7yWa85dIiocOQZ3YfxhXpDBsV5FYpaRyWEcMGvixInZevf5+fnhhx9+gEKRuYvHjh2Lw4fNk6AXhtTUVCxduhQwBmp69+6dY0Anq0aNGkl5lK9du4Z9+/bJq7y03Nzc8OOPP0rd2Tdv3pzj4/lP07NnT1StWhUAMGPGDPz999/yKtmIoog//vgDBw4cAIznQOvWreXVAADbtm2Tgm1yOp0OK1euhCiKsLKyMgskZu3p+uDBgxyDsmFhYbh06ZJ8tqR58+ZS/uzNmzdjxowZOQblstqwYQMWLlwIAPDy8pKC4M/KxsZGeow/KioKW7dulVcxs2zZMqhUKvj4+GD69Ony4iJ16dIlREVFoXr16lIv2BepcePGgPFa37hxY67Hb//+/dKNngYNGlgMRCYmJkq93UuWLJljyhi5MmXKoGHDhgCAHTt24PbtzEfw5e7duyfl9W7cuDEqVMgcaRkAqlevLgW+V69eneO1IYoi1q9fL6W+yBqAdnFxQXh4OKKiorB+/foc94coili7dq10Y6RZs2bSZ2NhtFEYbGxspHQIYWFhOX4O63Q6rF27FqIowtnZWRow08bGJluOckuTnZ0dYLw55u7uDi8vL5QsWVLKoxwQECAF7k1PgeRXu7JPBrnI
+pqIiIiIXk0eJZzw6x/90Ot/jZCYkIotG89hy8ZzSExIRa//NcKchQNRngOhET0XuQZ3XyamQcpetMOHD2Ps2LGAcZ1mzpwJNzc3eTUAQOfOnTF48GDAOBjP4MGDpcdqC8v58+elgEDdunWlNANPo9Fo8NZbT3IyLVq0KMcA4suoSZMmUl5jAPj222+lQadM7OzsLPamzapEiRKYMGECFAoFDAYD3n77bSxYsMBiigUY0yj89ttv+OqrrwBjEHbSpEk5ngO3bt3ChAkTsu1bURSxbt06KYD6xhtvoFatJ6OyZu3Rd/ToUezcuVMqM3n06BGGDh2K+Ph4eZHE1tYWY8eOla6fL774At9991229TERRREbN25E3759YTBkDmbwzTffFEpqlvbt26NUqVKA8WbH7t275VUA4z779ttvodfrERMTg6CgIHmVAtNoNPJZgHH7TQNLBQUFFdmj7/nx5ptvSvv/p59+yjEwfv78eSlXqrOzs9RbUy4mJka6VipVqgQXl7zlv7KxscGAAQMgCAIiIyMxYcKEbClBUlNTMWHCBKn9/v37m90g1Gg06N+/P2A8zqNHj87WBgDs2bMHM2fOBCzcPGnVqpWUnmLmzJlSL+GsRFHE6tWrpQHmqlatavZ5VxhtFJasx3fcuHG4evWqWbn8s6Jt27bPnDohJ+XLl5e27dSpU/LiPFnfNQiTm1fH5ObVsb5r4V+zREREZFnFSt7YdXAM8+3SC+Hm7ohPhrXBxh0jcejkeBw6OR4bd4zEJ8PawM2/hdvNAAD/9ElEQVQ97+PZEFHBFJvgrrf3kxE+ly1bhps3byI6OjrHIFxRyJpnFwC+/vprqferJSqVChMnTpRSCFy7dg3ffvsttFqtvCoA4ODBgxg8eHCepvnz5wPG3pim9XnrrbdyDFpZ0r59eynot2/fPpw/f15e5aUlCAI+//xzKcgRGRmJ0aNHm+1bZ2dnKTh3+PBh7N+/32Jey65du+LXX3+FQqGAVqvFgAEDULZsWXz//fe4dOkSIiMjcfPmTfzxxx+oWLEihg0bJg3oNn/+/Kc+1j5//nw0btwYGzZsQGRkJI4ePYp+/fqhd+/eMBgMcHZ2xrfffmsWBLOxsUGfPn0AY27ht99+G+PGjcPNmzcRFhaG2bNno3bt2ti3b99Tg3OBgYFYtmyZ1Htv3Lhx8PT0xBdffIGTJ08iMjISYWFhWLZsGRo0aIDOnTtLPSbHjRuHAQMGyFp8NmXKlJH2c3x8PFq3bo2+ffvi6NGjiIyMxKVLlzBmzBhUqVJFSnny4YcfSr1FC5MpyJySkoKVK1ciLCwMsbGxePjwIU6ePAlBEKS81C+ar68vJk+eDIVCgZSUFHTo0AF9+/bF7t27pf02dOhQBAYGSj1hv/rqK7ObBVnFxsbi8ePHgCzdQV507NgRAwcOBIyfwy1btsT27dsRGRmJ7du3o3nz5tJn06BBgywOdjhgwAAMGjQIMLZRpUoVzJ49W+qFHhwcjDZt2iA+Pt7izZOAgABMmDABMKZladGiBYYOHSpdq7t370aXLl3Mrq8FCxagRIkShdpGYcl6fENCQtCwYUP8/vvv0v4YPHiwtB7ly5fH5MmTsz0pUlAqlQpjx46VUgk9CyulAl80rIQvGlaClbLYfLUgIiIiIiIq9orNL7Bq1apJAZnVq1ejQoUKCAgIyPWR9MIkz7Pbo0ePPOUjdnNzw4wZM6TA2ty5c7Fq1Sp5NcCY23Hu3Ll5mg4ePIgHDx5IaQScnZ2lx+/zqlKlSmjbti1g7Fm8fPnyHB9Pfhn5+Phg4sSJUuqLtWvXmqVV8PDwkB5pDwkJQcuWLeHt7Y0//vhDqgNjoPiTTz7B3r17pfygd+/exZgxY1C9enV4e3ujQoUK+Pjjj6VH0cuVK4e9e/eif//+uT6m3aRJE9SrVw/nz59Hly5d4O3tjaCgICmVRrly5bBv3z7Url1bvii6deuGIUOGAMbz79tvv0WFChXg7++P
IUOG4O7duxgxYgQ+++wz+aLZtG/fHufOnZNuNCQkJGDq1Klo0KABvL294e/vjz59+ki99pycnLBy5UqMGzdOenS7MHTt2hUbNmyQBqhaunQpgoKC4O3tjerVq+P777+XAvSjRo3CtGnTct2/z6pZs2ZSWo/p06fD398fgYGBOHbsGG7cuIHKlSvnGBx9ESztt9atW0v7bcaMGdBqtVCr1fj9998xcuTIHPdbVFSU1HM7v8E8Uz5xU6/gU6dOoV27dvD29ka7du2k86dt27aYNm2axXNHrVbj559/lnrw3r17F0OGDIG/vz+qV6+OmTNnwmAwwM7ODhs3brSY8mTgwIFYvHgx1Go1tFotZsyYIV2rrVu3xsaNGwHj9bVz506LN+EKo43CkvUGU0JCAj755BNpf8ybNw8A4O7ujr/++gt+fn7yxQuFn58fdu7cKQ3CSERERERERMVDsQnulilTBjt37kTTpk2leYmJidJgVEVJlOXZLV++fI6BC0sCAwMxbtw46f/BwcFSzsuC2LdvnzT4TePGjfP96Lx8wK3169cjNDRUXu2l1qlTJ/Tt2xcwHqfRo0dL26BSqTB79mx8+umnZsfqypUr0uusmjdvjqtXr2LHjh3o168fSpcubVbu4uKCbt26YePGjbh69SqaN29uVm5J2bJlcfDgQfz2229mvc+rVq2KxYsX4+LFizkGEdVqtZTbt02bNtI2qNVqtG/fHsePH8/XeVihQgUcO3YMx48fx9ChQxEQECAFxgHA3t4erVq1wtKlSxEREYFevXrlGCB8VoIgoEOHDggPD8f06dNRvXp1s3Xw9vbGBx98gGvXrmHKlCl53rb8atiwIbZt2yblW4YxXcHevXuh0+kQGBiYY6qNF+Fp+61cuXIYM2YMwsLC8NFHH+V63LL2bvfx8TErywuNRoOVK1di8+bNCAoKktZDoVAgKCgImzdvxubNm3N9isDe3h4LFizAhQsX0KtXL7Pew35+fvjmm29w9+7dHG9YCYKAvn37IiwsDGPGjDEbtE2hUCAwMBDLly/HxYsXpRsacoXRRmEx3WC6cuUK+vfvn21/TJkyBbdv30a9evXMlits/v7+OHTokHw2ERERERERvcQEsTh11aRiq1+/fliyZAlcXV1x+PBhVK5cWV7llXHmzBk0a9YMKSkp6Nu37zMP9kZEVJzMPJuZa/mTOgGIi4uTFxPRf8T1O/Oh02fPp05ERET0KnDVVIeXR/7SOCYnJ0OtVktPEBe2YtNzl14NDg4OL83geEREVDi6/H0EwbvOIXjXOXT5+4i8mIiIiIiIiIoIg7tU5NLS0vDw4UPAmCoh6+PkRERU/G24+SRFUtbXREREREREVLQYZaMic+XKFdy7dw+rVq3C7t27AeMgbrnl4iQiouKnfTkvi6+JiIiIiIioaL3ywd0pU6ZAEIQCT/369ZM3TblIS0vDyJEj4evri/79+yMjIwMKhQIffPABbGxs5NWJiKgY29A1CNNb1sT0ljWxoWuQvJiIiIiIiIiKyCsf3KUXIyEhARERTx7NrVixIlavXo2OHTua1SMiouJPqRAwIrACRgRWgFIhyIuJiIiIiIioiAiiKIrymURERETPKi4uTj6LiP4jrt+ZD50+WT6biIiI6JXgqqkOL4+W8tm5Sk5OhlqthpWVlbyoULDnLhEREREREREREVExxOAuERERFdgvp2/gl9M35LOJ6D9GRN4eChTyVo3yiLuTiIjov4tpGYiIiKhAOq07jE23IgEAHQO8sPi1qvIqRPQfoM9IROqNrdB6VwRUOTx2qNfiVuIVZNjYQVRzkN2C0Bv0UAgKGAQRJQQNHomJDPISEREVIQMMKKHVQWXliADnFvLiHBV1WgYGd4mIiKhAhClrzP7/aPDrZv8nov+OlDMLYFujNwS1nbxIcuHOclzX8AHCwtTSsRXCMkIRkn5bXkRERESFxNuqFMrBDQkZkajomvffPEUd3OW3KiIiIiqQjgHe
Fl8TERERERFR0WJwl4iIiApkY7cg/NyqFn5uVQsbuwXJi4mIiIiIiKiIMLhLREREBTasXnkMq1dePpuIiIiIiIiKEIO7RERERERERERERMUQg7tERERERERERERExZAgiqIon0lERESUV3qDiF9O3wSM6RkSHsfLqxDRf0TKmQWwrdEbgtpOXiS5cGc5rmuKpo/J3Ut3Mf3tn1GnbW30n9ZXXpxNTvWT45Ox7Y8deGNwazi5O5ot8zJq6dgKYRmhCEm/LS/Kl8fRjzGl+3Q8ingkL8rRWyM7o80Hb2DRyCU4u/0cPl8xHKWrlZZXe+5yOraW5He7Tdtc3IX+G4ZL+y+jw9A35UUvzI7ZO/HPtA3y2Ra5ervii7WfQ1NCIy8qVrRpWuxesAfVW1aHT+VS8uJCdWLDSfwzdQNG/DUMHn4e8uIiZzq+H80ZghqvVQcAiKKIS/suIz46Hk17N5EvYtGDO9FYOno5bp26BQCo3rIaBv02AFa2VvKqlAeLRi7B8X9OmB2XZ2HpMyU/n8WFJTk+GTPe/x1l65RFjzHdIAiCvMoz87YqhXJwQ0JGJCq6vi4vzlFycjLUajWsrIrmHC2ab1VERET0n9H57yP4fN8FfL7vAjr/fUReTERUrKSnZmDu0Pk4u/0cRINBXvxKUygVcPV2gYvXk8m5pDMEReYPYzuNnVmZi5cLbB1t5c0Ua46uDtm2UT69Ctt87+p9/NLnNzwMfygveikoVcps+10+uXq7QKEs/iGNlRNWY9ufO2DQ6+VFhSriZiTWTFyHln2bv5DAbk4uH7yC3wf/iZTHKfIii1ITU7Ho88W4deoWfCv7oEmvINR6oyYDuy/Yy/SZYu9sj86fdcLhVUdwce9FefErqfh/EhIREdELteV2pMXXREQvu9LVSuO3iz+b9yYSRRh0/62gromjmyM+X/kZfjg0UZom7PoG5esFAAD6T+trVvbDoYlo9r+m8maKLStbKwQv+DjbNsqnV2Gb/8/efYdHUX0NHP/O9mx6ISG00HsTpIOiiBQRBKXYwAZiwd6xYEMRFBX0VX8WRFERsFAElCK9994TAgkE0uvWef9IdshuNrQkKHg+zzO6zL0ze3fm7iQ5c+dct8uF+19886LNzVeXOO6+yzM/PUVw5L9/ZP25uJwVG9Sl6D3mffwHIVHBdLytg2/xJdPjwRv57OAnXqND3a4L64e5GbmkJqUR16wGT37/OHe9fQedBnb0rSYusdKuKX5/zl4C9dvXo1nXpvzxyQJyM3J9i684EtwVQgghRJn0q1fF72shhBBCCPHP27/2AJsXbuGaO7pc9gHxvMw8CnIKMAeY0Rv1vsVCQNHo/44DO5Cw8yib/tjsW3zFkZy7QgghhCizSZsKc+6Oal2P9PR032IhxH/EvynnbpchnZnx9kzityeg6BSaXNOYW1/oT2zdWL/17xk/lO1LdvDpiM+89nk55PUsr5y7/tjybHzywP+xf/2Bs+Zj9OTcfXzKKI7uOsqCz/4kMyUTvUFPhwHt6PvkzYRUCvGqX1qOR895aN+/nTbaa+HnfzJv8nye+v5xTh09za8TfifteJp2bgeOvo2YWtHaPnzPrUf8tng+vncy+TkFjPj4Aa7q2VLLuZuTnnNBOYM3ztvEV098Q2S1SJ787jEiq0VqZXtW7uXj+yZTqUYlnvh2FBFVI6Aov+qm+ZuZ/38LOXn4JADVG1XjtpcGUL99/RK5Id0uN1sWbmXepD9IOlD4dExE1Qj6P9OP1r1bodPrtHMUvyPBb/s9eU6L50de++s6rzrFj3Vp79n9/m50HtQJo8Wobec5zj1H9sBgMvD7+7Nxu92079+OIWMGY77AR+U9bS3ennPJz87nkxGfcXDDQW55pi89HrxRO46OAgdfPPYlO5bsZMBzt9B9+A1a2cnDJ1nw+Z+s/30DLqcLc4CJToM70feJPliCLD7vAskHk/ll3G/sXrGn1Pr++q6Hp58BPD/zGTJTMplw+0Ts
+XatjinA5HUOM1My+evLxayeuZq8rHz0Bj1X9WzJrS8MILxymLbd2bhdbr5+egr71x3guZ+fJqp6FA6bg88f+R/71x0o0Wfityfw/h0TMVlMPDv9KSrXqayV7Vm1l4/vmUzXu69l8KsDASjIKeCPTxew7rf1ZKZkAmAOMNHm5qu56bGbvNpZPOdug/b1tWtLcaXltC5+LSqu+DX6UvXd9BMZzPt4HhvmbMSWbz/refFXt3GXRgx4/havn0cUXRf3rzvAU9MeZ+2v61kyZQl5WflYQwLo9Ugvrh/Wldz0XGZPnMOaX9bhcrqoUi+Wwa8O9Lp+eK6XT0x9jKM7jzJv0h9kp+VgDQmg85DO9H64p1cfL+167Ha52bV8N398soAjW4+An+tP8e2L83wHSrsWX8y5uunRXjTv1owfXvmJg5sOobpVYmrHcMtTfWnZo0WJ62deZh4f3TMJnU7HY1MeLZeUOpJzVwghhBBXrFGt6zGqdT3f1UII8Y/YtWwXH90ziazT2XQc2IFqDaqy8+9dvNP/PXav2ONbXRMeG06HW9sTFB6EwWigbd82tLn5aoyWivlj7ErjKHDw6YOf8fNbM6lcO4b2/dthDQlg5c+r+XDox2SnZvtuckFUVeXX8b/z9VNTiKoaSceBHYiqFsnOv3cxftD7WoCgNAk7jjL5gf/zCuyWxVU9WtJ+QDtOJ55m4Rd/aY+Xp5/IYPqbPwPQ/5l+WmA3Pzufb575linPTiUtKY3WvVvRuncrThw5ycS7P+av/y2i+Ngrl9PFD6/+xP8e+4qTR1JocUNzWvduRXZqNl89+Q0/vPrTRT3SX69tXVp2LwyERFaLpPPgTtRrW5h6w5Zv59vnvuN/j33FqcTTtO7divb92+HItzP9jRl8MuL/yM/O990li75ezG8TfqfJtU1odl0zomtGYw4wsfDzPxlZ9xGmPDvVq35p6y9GQHAAt77QH1OAiT+/+Itje45pZX9PW86OJTtp2LEBXW7vrAV/dizdydv93mXNrLVUqRdL58GdiK4ZzZIpSxl7yzjSjntPrrdlwVbe6P02O5bupFqjqnQc2IGQSiGl1j8fgWGBdB7UkchqkSiKQsvuLeg8qCOBYYEAJO45xru3jmfR14sJDAuk48AO1GlVm41zN/Fm77eI3xbvu0u/UuJT2LNyLzWaVNduUhnNRuq1qYs9307CjqNe9RN3JeIocJCbkUtKfIpX2f61+wFo2rUJAGnH0xh7yzj+/OIvQiKD6Ty4E+37t8MSXPi9/2jYx6Sf8D/Rrk6vo+l1TWnUuSEUPb7feXAnqjb0P6mcp76n7waFB9Hh1vbaNbq8+25p4rfF82bvt1j582rCYsPpOLADVerF+j0ve1bt1epG14ym8+BO1GlVmx1Ld/JG77fZsmCr174BctJzmPzA/zFv8h/UbVOPVj2vwpZnZ9Y7v/DzmzOYcPtEti3aTpubr6Z+23okHUhm0n2fcmSrd39QVZWFn//J9DdmEBYTRseBHbAEBfDnF3/x3sAJ5+yzLqeLX9/7jU+G/x/x2+Np3q25djx9rz9nu6b4c7Hnas+qvbzT/z0yTmbQ4db21G9bj5OHT/L5o/9jxY8rfatjDbXSuHMj4rcncHhLYXD6SiXBXSGEEEIIIcQVJTsthzZ9rub1P19l6Dt38dLsF3jgo/tw2BzMHDur1CBj9UbVGPLaIKrUiyWkUgi3vtCf/s/2wxpS9tE+/wWqqmK2mnll3mie/P5x7hk/lDcWj6Fum7okHUhm35rCwNDFchQ4SNydyDPTn+KpH55g6Dt3MebPV+lwa3ty0nPOOnHOqYRTfPXkN+Rm5pZLYJeix377jOpNVPUoVvy0kn1r9uN2uZn/6QJOHDpJlyGdadG9uVZ/+Q8r2LxgC3Xb1OXtpW8w/OP7Gf7x/by+8FWia0bz64Tf2btqn1Z/y4KtrJy+irpt6vLOird46LMHveqvmrGaQ5sOa/XPV6eBHen9SE+MlsIAX/Gc
pct/WMG639fToH193l35NsM/vp97xg/lrWVv0q5fW/au3secj+Z5BaEB8rPyue+De3j48wd5+PMH6fVQD6/yilazeRy9H+lJXlY+sz+cq/WV+Z/MxxpSGPz1jNpLT07n57dm4rA5GDbubl6a/QJ3vX0HL81+gQHP3UJKfAo/vj4dR4EDgNPHUpn+5gwCgiw8/u0oXvz1ea3vdR7ciZT4FJZ8+7dPi84tslokg14ZSL02dTFajPR+pCeDXhlIZLVIbHk2Zrw5k/TkdPo81psxRdeyp354ggcnDyc/p4CpL3xf6rWsuKM7E8nNyKVO6zoYzWdGQ9ZpXQedXseBDQe186mqKgc2HNSC4MUDhrY8G4c2HSa8SjhVG1RBVVX+/N8iUuJTGPDcLdpxvGf8UMYuf5MOt7bnxKGTHFjnPdLWw2g2cuPwG7huaFcAWve6irvevoOm1xYGjn156nv6bpV6sQx5bZB2jb4UfTc/O5/pb8wgP6eAYePuZszCV7SfMQOeu4W8rHwWfPYnDpuj8CbPGz+Tn1PAAx/dpx2fp354gse/HUVAkIVpr/xY4qaUPd+O2+Xmjb9e4+HPH2TE5Ad4+IuRKIrCsmkrqNOqNmOXv8U944fy5LTHueWZvjgdTjbM2ei1H0eBg+1LdjD03cL2DX3nLl7/81U6DexI0oFkFn2zpMSxKG774h0s+noJUdWjeG3+yzz8+YPaNb1hxwasnL6KVT+vhnNcU/y52HO1d/U+ej/Sk9f/ek37Pjzw0X0oisLGuZuw5dm86lN000BVVXb+vcu36IoiwV0hhBBCCCHEFaVynRj6Pd1Xe6xTURRa9byK1r1bkXQgWXu8VJS/G4ffQJV6Zx41DggOoF3fNgCkHk8tVvPitOjegtpX1dL+rTfouapHYaA2+eCJYjXPSDuexqT7P+XU0VPcP/FeWvZo4VsFioIqY28Zx8i6j5S6fHDHh14BhMhqkdzyTF9Q4fcPZrP2l3Ws+GklVerFcvPjN2mPLWedzmbNL2sxGA0MeO4WrxQVEVUjCh9xV2HVzNW4XW4cNgdrf1uHTq/j5sdvKlG/z6jeBAQHkLDTe9RlWeSk57B6xhpMASZufaG/NoKUosfsb36yD2ExYWyct4nUY97nsnKdyjTs2MBrHcUm0PJNUVDaeo+1v64rcex9l+3FgvmKonDtnddQt01ddizZyaqZq/nlvd/Iy8rnlmf6Ub1xda3utkXbOZVwita9W9HulrZaIFNRFK4b2pXGXRqxZ8VeEotGAO9ZsYeMkxl0HtLZ6zPqDXp6jryRSjWiOJ14moKcAq2srA5uPMSBDQep1aIm3e69Hr3hTG7Zlj1acO0d15z3DZMjRSNJqzbwnhehcp0YYmrFkLj7mDbhVG5GLom7j1H36jpUbViVQ5sOa/099Vgqx/YeI65ZDUKiQsjNyOX4vuPE1I6hbb82Xo/E6w16KteOASD9RMWn66qIvutP/PYE4rcn0KJbc6/PrChK4ajc1rXJSMkkJz2XnUt3cuLQSTre2oFWPa/yOj6NOjXk+nuuIyc9hw2zNxR7h0KdBnagUlwl7d/VGlYlPDYcnV5Hp0EdvX621b26Ljq9zu9IV98+brQY6fVIT8Jiwtj61zYtjYYvR4GDFT+tRFVVbnm6r1dqjoDgAAa9fBvWkADWzFrr933PpiznKqJqBO37t9OuqwD12tQlPDac08dS/X4HI6tGYAowkXwwGVuxFChXGgnuCiGEEKJMHG6V99bt4711+3C4Sx8BIIQQl0qNJjUIjgzyWqfT62jQrj74jEYT5SuqRpTvKkJjCh8F9/eH94Wq2qCKV5AEICw6FFMpj1GfTjzN54/+j5T4FDoP7kTr3q1KbF9ccEQQ4bHhpS4hlULAZ3tPeob47Ql899I09Ho9d7wxxGviqpT4FFLiT1GzRRxV6nvn2QSoUr8KodGhxG9LICc9l7zMPJL2JxNRJaJEXk6Atv3a8MGm8XS/v5tv0UXLOJFJenI6
NZpUJ7pY/mKPyKoR1G9Xj6xTWZwoyheslVWLwGw1e60rC71BX+LY+y7FR6Hik55h+usz2LNyL617taLToDMjCN0uN/vWFQZEW/c6kzPUw2gxUq9tPZwOJ4e3FI6KPrS58P/129Ur0Xeiqkfx5pLXGfl/I/zm6b1YB4tG0za/oXmJPKGKotCgQ+G1bM+qvV5lvhw2B6eOnsJoMRISdeYGAUVpIWq2iOPkkZOcOFR4PtOOp5F6PJXmNzSnRpPqnDhyUhsdfGjTYfKy8rmqe0t0eh1B4UE8/eOTvP7nqwRHBpNxMoPdK/fw5xd/8eHQj5k9ca7X+1WkS9V3j+48iqqq1G1TxyvgTlEKgGenP80Ls54lvHIY8dsToCiFhW8/K1zfFKPFyJFt8SWCjrHFbpAVFxYTRiWfa6zJYsRgMnit8/DXxyNiw6l1VU0yTmRwOtH/DbfMU5kc23ucsJgw6lxdx7eYiKoRVKlfheSDJ0g9dvb0Dr7Kcq6i4yoREOKd098SZCGqWL5zX6HRoQSFB3E6MRVbbsmRvVeKkj1MCCGEEOIC3PLLKp7/ezvP/72dW35Z5VsshBCXXGy92BJBGIoFGUvLASnKxhRgIqRYQNNXeRz36JolgwFnc3DjIRJ2HEVRFLYs2Frq6F6K2j/q60d4Z8VbpS4PfHRfiXyceoOemx7tRUSVCFRVpevd15YIiORl5eF2uUk9nsbMt2fx/egfvJbZE+fgtDvJOp1Fxol0MlMyyUnPISI2HLPVf+C6vGWnZWPLsxFRxX+wS1EUIovyB+dl5HmVBYUHlQi2lkWbm68ucex9l0adCnO1FlezeRw9HrwRVVUJDAvk5if6eAXhHDYHuWm5KIrC6llrSpyH70f/wP6i4O+xPcex5dlIO552zr5d3jzflT0r9pRo3/ejf2Ddb+tRFIVTCaf8Poru4Xa5sefZURQFnc77mqgoClfd2AK3y0389sIbXgk7juK0O6nVoia1W9bSAmxul5vdq/YQGBZIjaZnRkHb8u38/NZMHm38OC90Gs3H90zml/d+48iWIwRFeN9gq0iXqu96rh/nug45bA4yT2VitBgJjw33LQYgKDyQgOAA8rPzcdmd2npTgImwcpjAs7T31ul1mAPMuF1u8rK8j4VHXmYeBTkF2PNtzP5gTon+N3PsL6SfyKAgt4C05AsL7pblXIXFhJW4/p4vt8uN6i7Mi34lkuCuEEIIIcrkj0NncoUVfy2EEEL8Gwx4vj99n7qZnPQc5n38x0VNQnYuibuOaY+fb1+yg/Qk/4+ipyens/Ln1aycvsprWTNrLTnpOb7V/7WKz2T/b1KQU8CB9YU5XnMzctm+eHuJvJ0U5ZbdsWRnifOwcvoq9qw8+2jYS2n/+gMl2rdy+iq2/rXN7+e6UFXqVyGkUgiHNh/G5XRxeOsRIqpEEF0zmupNqmO0GDm08RBZp7NI2HGUqvWraAFDR4GDLx//iiVTlhJbpzK3vz6YV/8YzQebJ/DRjolcP6wwl+6/zb+t7xrNRnSG8g/N+QvoX6i8rHzW/LK2RP9bOX1ViZQJFeHfdq7+zcq/BwkhhBDiP6V//TOzGhd/LYQQ/5STPo9yemSeLMwvGFOUC1Jc+dr2bUP3B7pxze2dqdk8jk3zN7Nx3ibfamWSdjyNXyf8jtFspMUNzUmJT+GPTxd4BZGtIVZ0eh1t+7bh/w5M5rODn/hdPt4xkRpNaxAcFYI1xEpacjq2vEuTJzI4Ihiz1UxaUprf0aBul5sTRwq/W//GoIuqqqz4cSV7V++jftt6RFSJYMFnC0koejyeokBaYEQgRouRF355rsTxL77cM35o4SjKymHY8+1kncfkZeUlvHIYAA9+MrxEu4ovT/3whN/Rjx46vQ6T1YSqqrj9pM4KjQ6lRpPqJOw4SvKBE8RvS6B6k2oEhQcSVT2SqGpRJO45VnjzIimdJtc21t4vYedRdi3bTfXG1Xn6xye59s5rqFK/yj8yAeWl6ruenx0p
8Sm+RV6MZiOhlUJxFDhIT/Z/oyfjZCY5aTkYTAYUXfmH5uz5djL85NR12BzkpOeg0+uw+qQ48LCGWrEEWah7dR0+3PZ+iX5XfGl+fTPfzc/qUp0rXzq9rkKO87/FlfvJhBBCCHFJ/NK/I5/e2IpPb2zFL/1LnxlXCCEuleP7krQJgjxcThe7lu9GURRqNo/zKhOXXmzdwgl6PAH34o5sKb8J73R6HYqiEBgWSL+n+qLoFOZ8OK/cRp25nC7++HQBKfEp3HDf9Qx99y5qNo9j1YzVbF90ZsKvqOqRhFUO4+iuo2SnnnuEblBYIFXqx5KWlEbywZJPxRzceIinWj/Lj69Nx2g2ElE1AkeBg9xM78eY3S438TvOBDfPJqxyKOGx4RzdlUjKkZLBq7TkdI5siScwLPCsOS7/KQnbE1jw2UJCo0O5+5076ff0zeRl5fPb+7O1SZ90eh01m8XhKHAQXzTR2NkoiqJN4Ld/3YESo2XzMvMYd+t43uk/jvQTGVr+55z0HBw2h1fdU0dPk3Hy/FKT1GpZE4B9a/aXeM8LYTQbqVSjEo4CB1mns3yLMZqN1GtTl/SkdHb8vZPU46nUbBanBf4q143h6K5EVs9ag6JTqNP6TLoRT6qRqg2qeE2KRdGo3n1rzz3ZW3m5VH23RpPqKIrCwQ2HSjwB4Ha5+d+orxjd9VWO7jyq/ZzZ+fcu3K6S6QD2rt6L2+Umtm7sRacaOJfdy/eU6D+Zp7I4tuc4MbViqFzH/43O4MhgKteKIWl/EmnHLyztwrlcqnPl4UlxE1U9EnNg6TdCLncS3BVCCCFEmT10VR0euqrkhAtCCPFPSNydyIqfVmp/UKuqyvrfN7Bt8XbqtamrBWvOxpZnu+BZwMX5i6xa+Ef7pvmbvY7zoc2HWTZtebGa5adBh/p0vLUDpxNPs/CLv/wGXC7U9kU7WDVjNVXqxXLd0K5aEBng1wm/a4GR0OhQ2t/SlhOHTvL7+7NxFJwJ/KmqysZ5m3i4wSg+uPNDcjNyMVqMtLulLW6Xmz8mz/e6WWHLt7Pw8z/Jy8yjXpu66PQ6oqpHoqoqm/7YrAWdVFVl84ItbP1zm7atr+JByKDwIDoO7IA9386sd38t8Z5zJs4l42QGrXu3OmfO0UstP7swiJuXlc9Nj/aiUlwlWvW4imbXN2Xv6n2s+HGlFuRq1fMqQqNDmT1xjjZZmkduRi4T7/qIRxo+xoY5GwFock1jQqNDWfnTSg4Xu/Ggqiob5m7kyLZ4KsVVIrRSCNawQKwhVg5vPszxvce99jvnw7ml9jnfwHyd1nWo2TyOZT8sZ9Mfm70CdI4CB9+9OI2H6j/K3I//0NaXplaLwkBxaTdN6rSug6JTWPDpApx2pxbA1el1NO7UiKxTWWz9c1uJYKBnNPrBjYe8bpa4XW4WT1nC7hV7tHXnI8PPjZ7zdan6bs3mccQ1q8G2xdvZsnCr13nZtXw32xZtJyQymEpxlWh6XVMq14lh9aw1bF6wxavunlV7WTJlKdaQADre1l5bX97Wz17P3tX7tH87ChzM/2QBGSczaN+/LUHh/vMim61mOg3uSF5WPtPfmFHiZumhzYd5qtUzjOnxBqcSTnmV+buxUdylOlceqcfTsOfbqdWiZoUF0f8N9GPGjBnju1IIIYQQ4mIVFJR9NnQhxOXJkbwFY0xTFH3pj1KezNhBqqVseQBLk5mSyepZa6lcJ4aN8zax/vf1HNt7nHkf/cHyH1cSFBbEPROGEVk0GshTP7ZuLC1vbAFFk2Md2HCQw1uOkLAzkZT4FGo0qY7JYmLKs1P57OEvMJqN1C02YdbCz/9kwpAPOJ2Yqu3nbOvLWy1zbTJdGaS7/D/+WxYuh4sNszeSejyNNjdfTUwt/yO9tv61jeSDJ+h4WwdCfSYDOnkkhQ1zNlKtUTXtOARHBrNr2S6ObItn9YzVJO1PZsmUpcz5cC712tajIKeA6JrRWv1Dmw6xd/U+v23wdx79rVN0ClXq
xbJ5wRYOrD9AXNPqxNSOwZZrY9XPq8lJz2Hrn9tYOnUZi79ZUuoSvy2eJl2bkJWSyZTnppKXmcd9799D1YaFqYkiq0aQnZrN7hV7sOXbaXJtY3R6HTWa1uDYnmNsmr+ZZd8v4+iuRHYt382v7/3Oih9XotPruPWF/tRoWgOKRjdnncpiy8KtLJ36N8f3HmfHkp1MHzOdxD3H6Dy4E92H34BOpyMoIphNf2zm4MZDbJi9Qev3y75fTvNuzUhLSqN+u3pavy0MTG7i6K5Ekg+dID8rnxpNqlOtUTXSk9LZtnh7ifc8vPUIddvU5fYxg7EEWUo9zsWV9h0obb3nPCcfPMHK6atKHHvfpWqDKkRVj2Tpt3+z4qdVNLu+Kbc81Re9QY/eoCe2TmU2zdvEwY2HaNihAWGVw7QRgRvmbGTl9FXsWLqThB1HWTVjNdNe/pFTR0/TuHMjej7UozCNQ1ggUVUjWff7elb9vJr9aw9waPNh5n44j1Uz1hBdM5p7xw/DGmolIDiA9OR09q87wKqZazi6M5Gtf21j2is/AgrRcZWw5dnoPKgjlsDCY3jyyEn2rNpL/PYEUo6kEFMrhrCYMKo3rs6WhVtZ++s6Ns7bROLuRNbP3sCPr/3E4S1HqFI3lltf6E/AOdIg6I16NszdBEphYLv4BHMAlkALe1btJfV4GuFVwun+QDft/DpsDtb+ug7VrXJ179Zc1bOlNlllYFgg8dvjObrzKCt+Wqn1lR9e+ZFdy3bTYUB7ju9LIrRS6Fm/x7ZcG+t+30Di7kROHD6JXq87a+ocT58Liwmjbd82GIwGgHLvu/6YLCaqNazK5vmbWff7BrYv2s7RXYks+moxCz5biCXIzP0T76VSjUoEBFmIqR3D9sXbS9T945P5OB0u7nh9CI2vaawd09Kuo55rFODVdyjlc3iOsynAzIqfVrJ/7QH2rd3PT2Omc3DTIZpc05j+z92iTSS39a9tHNt73Ou8xNatTE5aDpvnb2Hx10s4uvMo+9bsZ97kP/jjkwU4bE56jOxB065NUBSl1GtK1qmsEu0rz3Pl+RmVl5Vf4tgALP9hBQnbE+j9aC8q1YjyKrsYwfoQIrBic+UQFVDbt7hUDocDvV6PXu/9/SsvMnJXCCGEEEIIcUVp27cNT37/OEazkdUz1pB0IJnOgzoyes6LxDUrDJyVRlEUeo7sQa2WtTiy9QhLp/7NqaOnfauJMgqODOaJqY/ReVDh6LC1v67jZHwKQ14bxF1v3YHRUjEjrCrFVeKmR3uhulV+mzCbbJ88qtlpOaQnp591yTqVhcvh1NIxdBnSmQYd6mv70Ol19BjRnajqUV7pGQKCAxj56QgGvzoQS3AAm/7YzOoZa8hITqfDre159Y/RXNWjpbYfvUHPHW8MYfjH9xNRJYJNf2xm7a/rMAaYGDbubm5/fbAWqKtSL5Znf36apl2bcPpYKqtnrMFuczBi0gP0fqSXFjzyCI0OZcBztxAUFsiWBVtZ99t6bHk2zAEmhr13N8M/vp9K1aO83nPwqwN5/JtHCakU4rWviuByukocd3+Lw+bQ0jFYQwLo+0Qfrzyd1RpV48YR3UukZ7iqR0tenvsiza5ryrE9x1k5fRXbFm0nplY094wfyshPRxAQfCZoelXPlrz6x2iadm3CgY0HWTl9FSnxKXS961qem/E0EVUjoOj6cesL/Rny2iBCokLYvng7W//cRru+bXjmxyf9Bi07DGhP616tOHn4JEun/s2xPccAiGtWg9FzXqTzoI5kJKezesYaNv2xGUtwAAOeu4XnZjyjve/ZRNeMplHnhsRvi+fUUe9RlhTlV619VWGQKq5ZDUKizpzf6JrRRFQpfI/GXRp69SOz1cyDk4dz/T3X4Xa62fTHZjbM2UjT65ry6h+j6f1IT4Iigkg+mEyeT7qQ4qo2rMrNT/RBVVXW/baejfO8Ryqfr0vVd2u2qMkrf7xM50EdSTqQzMrpqzi0+TCtel7Fi7+9QM2i
kdIAjTo11OqmxKdodZsVHaMOt7Yv8d0sT7e/PpgBz/fn2N5jrP11HTqDnsGvDizRv/3RG/QMeW0Qwz++n5ha0WxbtJ2V01dxbM9xml3XlOdnPcsN912vtb+0a4o/l+pc5aTnsHf1Pmo2j7vi0zEp6sV8a4QQQgghihQ4XXywoTCv2lNt6pOfXTKnmxDivyFv09cENB+CYvQ/SQvAtiPT2BcqY0zK03XB3Uiwx3PYdsi3SAghOLjhIB8OnUTPh26kz2M3+RaLK8zCz//k1/G/8/AXIy94wrMrya7lu5n8wKfcPmYw19zRxbf4olQxVaUOkWTZk2kQcYNvcalyc3MxGo2YTBVz41J+qxJCCCFEmfT/dTWjl+9k9PKd9P+18JExIYQQQgjx71Drqlq0uKE5G+dtIuu092h1Ia5EjgIHf3+/jLimNWjdu5Vv8RVHgrtCCCGEKJMFh0/4fS2EEEIIIf55eoOemx7rTU5aLpvnb/YtFuKKc3DjIfat2U/vR3oSGBboW3zFkeCuEEIIIcrk1gbV/L4WQgghhBD/DlXqxdL3yT4s+OxPTiWUzL0rxJUiPzufeZ/Mp33/djTt2tS3+IokOXeFEEIIUWZfbD0MwIiWtUlPL//Z2oUQlwfJufvPkJy7QgghRMWTnLtCCCGEuGKNaFmbES0LZ1oWQgghhBBCCHFpSHBXCCGEEEIIIYQQQgghLkMS3BVCCCGEEEIIIYQQQojLkAR3hRBCCFEmuQ4nb67ezZurd5PrcPoWCyGEEEIIIYSoIBLcFUIIIUSZDPh1Da+u2MWrK3Yx4Nc1vsVCCCGEEEIIISqIBHeFEEIIUSZ/Hjnh97UQQgghhBBCiIolwV0hhBBClMnAhtX8vhZCCCGEEEIIUbEkuCuEEEKIMvm5Xwe+6tWGr3q14ed+HXyLhRBCCCGEEEJUEAnuCiGEEKLM7mtek/ua1/RdLYQQQgghhBCiAklwVwghhBBCCFEuFN8VQgghhBCiQklwVwghhBBCCCGEEEIIIS5Diqqqqu9KIYQQQojzlWVz8MGGAwA81aYerrwc3ypCiP+Igp0zMDe4CcVo9S0CwKW6WJU8gzxLAIreBIqKUTGhkzEnF8zpdKLX63HjppGlEYn2RPLVfN9qQgghhCgnOnTUVIPIdaTSKLKHb3GpcnNzMRqNmEwm36JyIcFdIYQQQpRJ9+nLWRR/EoAbasbw843NfKsIIf5LVBWU0hM0rMpZQYYznQ5BnViTs5qbwm72rSLOQ3Z2NgEBARgMBlyqC72i960ihBBCiArgcBdg1Fl8V5eqooO7cotcCCGEEGXiCez6vhZC/EedJbALYFAM2khdnSJ/jpQHCewKIYQQl86FBHYvBfltSgghhBBlMrhRdb+vhRBCCCGEEEJULAnuCiGEEKJMfurbnik3tWHKTW34qW9732IhhBBCCCGEEBVEgrtCCCGEKLNhTWsyrGlN39VCCCGEEEIIISqQBHeFEEIIIYQQQgghhBDiMiTBXSGEEEIIIYQQQgghhLgMSXBXCCGEEGWSYXPw8vKdvLx8Jxk2h2+xEEIIIYQQQogKIsFdIYQQQpTJrb+u5u01e3h7zR5u/XW1b7EQQgghhBBCiAoiwV0hhBBClMnSo6f8vhZCCCGEEEIIUbEkuCuEEEKIMhnSsLrf10IIIYQQQgghKpYEd4UQQghRJj/0bcd3fdryXZ+2/NC3nW+xEEIIIYQQQogKIsFdIYQQQpTZXU3iuKtJnO9qIYQQQgghhBAVSIK7QgghhBBCCCGEEEIIcRmS4K4QQgghhBBCCCGEEEJchiS4K4QQQogySc238+KyHby4bAep+XbfYiGEEEIIIYQQFUSCu0IIIYQok9t+W827a/fy7tq93Pbbat9iIYQQQgghhBAVRIK7QgghhCiTZYmn/L4WQgghhBBCCFGxJLgrhBBCiDK5o1ENv6+FEEIIIYQQQlQsCe4KIYQQoky+v7kdP/Ztz4992/P9ze18i4UQQggh
hBBCVBAJ7gohhBCizIY0qs6QRtV9VwshhBBCCCGEqEAS3BVCCCGEEEIIIYQQQojLkAR3hRBCCCGEEEIIIYQQ4jIkwV0hhBBClElKno1nl27j2aXbSMmz+RYLIYQQQgghhKggEtwVQgghRJkM+n0NE9bvZ8L6/Qz6fY1vsRBCCCGEEEKICiLBXSGEEEKUyfKjp/y+FkIIIYQQQghRsSS4K4QQQogyuatJnN/XQgghhBBCCCEqlgR3hRBCCFEmU/u05ed+Hfi5Xwem9mnrWyyEEEIIIYQQooJIcFcIIYQQZTawYTUGNqzmu1oIIYQQQgghRAWS4K4QQgghhBBCCCGEEEJchiS4K4QQQgghhBBCCCGEEJchCe4KIYQQokySc/J5ask2nlqyjeScfN9iIYQQQgghhBAVRIK7QgghhCiTwbPXMnHDfiZu2M/g2Wt9i4UQQgghhBBCVBAJ7gohhBCiTFYmnvb7WgghhBBCCCFExZLgrhBCCCHKZGjTmn5fCyGEEEIIIYSoWBLcFUIIIUSZTLmpDb/078gv/Tsy5aY2vsVCCCGEEEIIISqIBHeFEEIIUWb961elf/2qvquFEEIIIYQQQlQgCe4KIYQQQgghLi3Fd4UQQgghhLgYEtwVQgghhBBCXDJO1YlbdfuuFkIIIYQQF0FRVVX1XSmEEEIIcb6O5+Tz3rp9ADzXrgGxVrNvFSGEAEBFJceZQ74rD71iwK26iTBFoFf0vlXFOZw4cYKwsDAsFotvkRBCCCH+RU6fPo3FYiEoKMi3qFxIcFcIIYQQZdJ52lJWHTsNQKdqUay88zrfKkIIIcpZcnIyYWFhBAQE+BYJIYQQ4l/k1KlTWCwWgoODfYvKhaRlEEIIIUSZrDme6ve1EEIIIYQQQoiKJcFdIYQQQpTJsKZxfl8LIYQQQgghhKhYkpZBCCGEEGU2+2ASAH3rVvEtEkIIUQEkLYMQQghxeajotAwS3BVCCCGEEEKIy4wEd4UQQojLQ0UHdyUtgxBCCCGEEEIIIYQQQlyGJLgrhBBCCCGEEEIIIYQQlyFJyyCEEEKIMjmalce4tXsBeL59Q2qEWH2rCCGElwJbJiB/hpRFSkoKISEhWCwWAOyOHEzGIN9qQpRJbt5JAq0xANjsWeh0BvQ6Izl5J3C7HRgNgTicuYQE1UCnM/hu7odKRlYCZlMwer1Z28btdqGqTtxuFw5nLgZDAG63E6slCp1OT3ZuEsGBVcjIikdRFIyGQAyGAHQ6A6rqwu1yYDaHkpEVj9FgxWwOIy//JKqqgupGBYICY9EpevJt6bhcdoKslbHZszCbgsnJPYHRaEVFxeHIwWQKQa8zo9PpUd0uXG47FnMY+QXpOJy52vXLbArBaAgkO+c4JlMwLrcDqyWSrJxjhIXUJCsnEVAAhSBrDDl5yaiqitkUisFgQVXdGA0BFNgysDtyURRQVZWQoGrY7NmYTcEU2LJQFAW9zgiKgtNpw+HMQVXdmIxB2B252B3ZRITWJTfvJEZjEEajlYyseKyWSOwOT3tVQCHQGoPNlonb7SQgIJL8glRU1YXqdoECRmMwFlMoOXnJgFL0PsHo9Sbsjhws5jDcbgcGvQWXy47TVUBObjJBgVUIsEQAKtm5SZhNYbhcdmz2DEKCqpGbn0JwYNWi8+0AICfvBABB1srk5J3EoDdjMYfhcjtwugpwOW3o9EYArJYorQ9l5RxDrzehU0xk5yYSZK2C05WPyRiEyRiEzZ5FfkEqFnM4DmceCgoqEBxYBZfbXtSHTxJkjSE14wCBAZVwugoAMBmDCvetM5KXfxqTMRCzKbTofQvfy9Nvc3KTsAZUIjs3CYPegk5vwqA3F/URBaPBis2eicUcTlb2Ua09qupGURRMxiAMhgD0OpPWnty8FEymIGz2HCymEHR6I05nASigoJCTdwKrpRIo4HbZCbBEkpN3giBrZbJyjgOF/Scn9wSKTqcd8+LHLiSoWmHdou+GAphMwZiMweTmnUCn
NxNgDiM3P4XAgGgK7FnoFD1ut73ws+Qc074DQYFVUN0uCuwZuFy2omuBnqzsRKzWaOz2bJwuGyFB1Yq1wx+VzOwEVNWF2RSO0WhFp+ix2bMKvyfGQIwGz98YKlnZiSg6HXqdBUUBsylMOy82eyZGgxWHIxezORSbPQu7Pbvw56QpBLfLhlt1oyg6QoKq4XDmYTaFkG8r/AxmYzDBQcWPW0kulx2HM893tR8qmdlHcTmCMZlMBAWd/ee001VAkLWy72ovqurWrjkU9VkJ7gohhBCiTDp+t4Q1SakAdKgSyeq7r/etIoQQXmz2LDZsm+i7WpRBbHQbUlK343LZfIuEuGiRYQ1IzdgHgNEYhNORh4qbQGs0YKCgIBWjMYgCW+HvAecjOLA6dnsWDleeFuQDMOjNGE0hqG4nqurC7swtDDhCUTDqGFZLJKqq4lad2B15qKoTvc6E2RxKXv4pgoOqkZ17DFQICa4BqorLZcOtOskvSENRdIWBNRRUVCzmCApsaQQHFm0HhARWw+bIxuHIwa260OkMBFgiyM1LwWwKxWgIQlHA4czD6S7A6cgnwBKJy2XH5bbhctkJCa5OVnYiwUHVUN0qTlceBbZ0QoNrYHfk4HTZcDhyMRgsOJ0FWEyhKDojOr0RnaInO+cYZlMINnsWRqMVvc6Ew5GLy+3AoA/AbA5BwUBO3rHC0FxRhC44sBo2eyZ2RzYGvRWnK4+gwKooig6HIxcVFzZbZlE4rzAUZLVEYzRasNtzQCkMWtkdOQQHVkVVXThcBbjddhyOfFBUUNHardMZMOitKIqKzZ6tncsgaxVy8gon2zXqrThceQQGRJObn1JUQ0GvMxAYWBmHI5f8gjQCLFG43Q5s9kx0OkNR4Dobl9uFou25UJC1Ctl5SSiAUW/B4SooCohnFQXsw9DrDeTlnyY4uDq4VZyuAvILTqPTGXG7HVhMYRTYMwgOrEp27nGCg6pjt2fjctmKAr0qBoMVZ7EAnsUcToEtXft3gDmSfFsqBr1FCw7rdAZ0OiOKosfhyNGOdHBQdXJyj2MyhaK6nYCKAkXfAzdmUxg2ewYB5gjyC1JRdDp0GHCpDnQ6I6rqQqczotPpcTjysJjDsduzcauFN0HyCk4TYIlERaWgII0ga1VQICf3uNZegNCgGmTmHCXQGoteZ8DpsuF2O1BVFzZ7FgGWSNwuOzZHNoHWGHLzTqLXGTEYrLjdDhzOPALM4YCCy13YVwwGc+FnMAaTb0srPDaWwhsHiqIjKLAq2TmJXu3wJ8Acgcttx626teMeYIkiv+C0b1VCgqpjs2ejqk4ousHpYTIFYbfnaG0wGoIwm4JRFZXc3BMEmCNwK4XducCWhtkUXHgzxRyK2RhCRFgDqsV2KvZu/q3a+EbhDaRzCAupS0bWQd/VftWv3Z/oyOa+q0tITtnIoYR5WAMq0arpw5KWQQghhBBlsy658Jc439dCCCEurfMbOSnExdEp3v1LpzNoow8vRGHwS++7GhQdiqLDZAxEp+hRlDMhPU/fNhoDMRgsKF6hDAWdUri/4t8BnWIofC/FgF5v1kbQFqfXmwrrFt+uRPsUlKL9nyk3YNCbUZTCdpiMQaDotXbpdYWjTT11C9+/sE0Gg0XbzlOvOE9bPHUo/ISg/bvwM3h934s+VuE58Q7znGmvRfscxY+tp47RGITBYPE6lp62ex/vM33Bdz8exdumffZi6xSl8D96nbHYaMxin1kFRdGjlAjrFvL67MXW+V1f1A887fC8h9FY+L6ebQrrmEocv+J8+/qZ/nPmPCqKXgvuFldYR8VosKLXm9DrTF51PMfhzPsrXq+1sqJYol5vKjqQhd8Lz/+LH2+/x6NYW73OsadP6s+068z2hW3x7NtgsGIwWLTPr/gcY3yOle+1ozSe9p/5/hc/Bt7Odr7O9E///dZkDMKgN5foExXFc5zKk6e/GPSFT+9U7CcQQgghxBXv3mY1/b4WQgghhChv5x4nd6lcQEuK
Vb2ArcpVaYHY4vzW+acaLC7OxZyv8xh96svfFqrftRXHb389T4XbXtr2ViQJ7gohhBCiTL7sdTVzb+vM3Ns682Wvq32LhRBCCCGEEEJUEAnuCiGEEKLMbqoTy011Yn1XCyGEEEIIIYSoQBLcFUIIIYQQQgghhBBCiMuQBHeFEEIIIYQQ/2pvj/mVLm3HnPcyauQU8vPsTJu6ki5tx7BqxT7fXf4n7NubRPdr3i5xfM62eI7V22N+pfs1b7Nvb+GM95er/Dw7o0ZO+dd/Fk8fL95XV63YR5e2Y3h7zK9edYUQQojiJLgrhBBCiDI5kpHLyIWbGLlwE0cycn2LhRCizELCrETHhHotloDC2acNBn2JssjIIM8k4v9pRoOeKtUivI9NVLBWHhYeWOLYmc1nZlIXQgghxL/fZRXczc/P57333mPRokW+RZr9+/fTp08fTCYTiqKgKAp169Zl9+7dvlUviKqqvPfee9o+69evT0JCgm81vxwOByNGjNC2fe+991CLZiPctGkTgYGBKIrCsGHDfDctkzVr1mA2m7X3/emnn3yrlGrcuHHadhezjBs3zmt/w4YNQ1EUAgMD2bRpk1fZlWzatGkoisJNN91EQUGBb/F5yc/PZ+rUqTRv3hy9Xo+iKOj1epo3b87UqVPJz8/33cSvxMREnn76aapWraqdp4iICIYPH86RI0d8q5eQlZXFe++9R926dbXtg4KCGDJkCOvWrdP69IVQVZV33nnHb58p7mL7Y9euXcnJyfHa14YNG7BarSXq+lvi4uJITk722v5sMjMzueGGG0p9b18ul4t58+bRpUsX7Zp1MefWV/Hr1fl858rajuzsbLp06VLi+JW2zJkzx3cXmtzcXKZMmeLV300mE126dGHJkiW4XC7fTUrYt28f9957L0FBQdp71qxZk/fee4+srCzf6iWU9Xh4lEc7Vq9eTZ8+fbz2cb7fO1VVufPOO0sc/9KWs30Hz+WOOev4fOthPt96mDvmrPMtFkKIMhv1RA9mzXnSa7nn/msAuOHGpiXKxrx9mxb8/S+rXTeGb394yOvYfDV1BDGVQ7FYjEz46M4Sx+7qtrV9dyOEEEKIf7HLJri7Z88eWrVqxfPPP1/qH9ZbtmyhXbt2zJs3D4fDoa03m81ER0d71b1QiqIwatQo+vTpA8CBAwcYM2aM1/v4o6oqEydO5H//+x8AAwcO5Mknn0Sp4KEEqqoybdo07Ha7tm7KlCmlHjtR/lRV1W5EXHPNNVgsFt8q53Ts2DGuueYahg0bxo4dO3C73QC43W527NjBsGHDaNWqFXv27PHdVONyufjoo4+oWbMmH3zwAUlJZx5HS09P58svv6RBgwZMnTq11EDRkiVLqFOnDs8//zyHDh3S1ufm5jJ9+nTat2/PfffdR27uhY3Y++WXX3j55Zd9V1eo/fv3V8j3QFVVJkyYwOLFi32L/Dpx4gQ9evSgT58+rFy5UruWFD+3zZo1Y+vWrb6bntPKlSt55ZVXfFf7VR7tOHXqFAcPHvRdfcE2bNhAo0aNuPfee736u8PhYOXKlXTr1o0HHnig1H6mqiqTJ0+mcePGTJkyxateQkICzz//PHXq1GHjxo1e2xVXHsejPNqRmZnJkCFD6NSpE/PmzfPaR/Hv3e23305mZqbXth4ZGRns37/fd3WFWJ+c6ve1EEIIIYQQQoiKddkEd2fPns3evXt9V3v55ptvyMjIAGDQoEHs37+fpKQkZs+eTVhYmG/1CxYQEMCkSZOoXbvwbvaUKVOYPn26bzUvGzZs4J133gGgXr16jB8/HqOx4h91SklJ0YJMNWvWBGDp0qVnDUiUpnv37gwfPvyClubNm/vu5j8nNTWV9evXYzAY6NSpk2/xOWVmZnLPPfdoAaA2bdowf/58kpKSWL58Od27dwdg7969DBgwgGPHjvnsoTDI9P777/PEE0/gdrsJCQlh3LhxxMfHs3//fp588kmMRiMOh4N7773X76j49evX
c+utt3L69GnwaceqVavo27cvFH0f+vXrV2qgqThPwHnQoEFaAO9smjdvXqKP+Vvuu+8+YmNjte169+5NYGCg1748o/hNJhN33HFHiX0UX26//XYCAgK8ti/NnDlzGDt2rO9qvzIzM7nrrru072jxY7pjxw4eeOABAA4dOsSgQYPO+ykBim4IjBgxwuvGTmnKqx1JSUla/zif60X16tV9d8H69eu58cYbSUxMBODuu+9m69atJCYm8s0331CjRg0o6mejR4/2eyPil19+4fHHH9f6+pdffkliYqLXZzl9+jR9+/b1G4wur+NR1nbk5+dz1113MXPmTADq1KnDjz/+SGJiIvHx8Xz22Wfa8Zg+fTrPPvus3xuN6enp2nWhdevWJc6D71KW6/YDLc6M8ir+Wggh/i3cbpUli3Zx58DJdGk7hmvbv85zT04j8aj3DSlPjtpvv17Oj9+t4rqOb3Jt+9cZ+8ZvFOQX/mxNPZ3N5A8X0qvbu3RpO4brOr7JmJdncirl3E9lXI7sNiezfl5H/5ve1z7vuLdnk5bq/YTS22N+5ba+E9myKZ4R9/yPLm3H0P+m99m88cwTYkcTTjP2jd+4ruObdGk7hu7XjuXjDxaQm2vz2hdFv8du25LAIyO+1up3aTuGoUM+ZcmiXX5/h4w/copHR3yj1X31xRmkpmb7VoNi7T2WmMpnkxdxQ5fC3MT9b3qfWT+vw2Yr+bPVZnMw6+d1DOz34Xm1x+12s2TRLoYO+VSrP7Dfh6XuXwghhLgYl01w91xyc3O1oE3VqlV5//33qVevHrGxsdSpUweDweC7yUXxPFKr0xUeulGjRrFlyxbfalA0Quuuu+4iIyMDk8nEV199RVxcnG+1CrF06VL27t2LwWBg1KhRWK1W7HY7P/74o2/Vcxo1ahRffPHFBS29evXy2se3336Lqqrk5ubSunVrr7Ir1datW9mzZw/169enYcOGvsXn9N1332lBpuHDh7Nq1Sp69uxJbGwsXbp0Yd68eQwfPhyKArxffvmlzx5g8+bNvP7661B0c2HTpk0899xzxMXFUa9ePT744AN+/PFHdDodbreb999/32tUa2pqKo8++qh20+S1117zakfHjh357bff+Oabb9DpdCxevNhvO4o7duwYAwcO1ALO56NXr14l+pi/5cYbb+TEiRNQdMx8R8kXFBRoNziaN2/O5MmTS+yj+PLuu++e142hhIQEnnnmmfP+PD/++GOp57Zp06Z88cUX/PDDDyiKwoEDBxg/frzfYKYvh8PBG2+8cc4bYR7l1Y5du3bhdDoxGAy8+uqrJY6j79KyZUuv7XNycnjuuefIyMhAp9Px008/8e2339KiRQuqVavGPffcw/Lly6lXrx4UXU927drltY/ExEReeOEF3G439erVY/v27dx///1Uq1ZN+yw//fQTOp2O5ORk3n33XZxOp9c+yuN4lEc7fv/9d+bOnQtAt27d2LRpE0OGDKFatWrExcXx4IMPsmXLFq677joA/ve//zFr1iyvfQAcPnyYkydPAvDkk0+WOA++i+91+0J83qM1CwZdw4JB1/B5j//GNV4IcXmZ8O5cxoyeidVq4qa+VxFbJZw1qw7w0ANfceRwim91pk9bw2efLKZ9x7p07Fyf6tUjsQSYOLD/BCPu/ZLpP6whNNTKTX2vomnz6iz+cydDb/+UPbuO++7qsmazOXnh6R/5+IMF1KgRRc/eLQgKtjD398088chU0tO9n6Y5lZLFc09OIysrn569WxBXsxKxVQp/l1qzaj/33/058+dupVbtStx8S2uqV49gxk9reWDoF5xILvx9k6LA7o/fr+bRB79h355krunakJtvaU2r1rVIiD/Fay/N4Ocf1xZ7Z1i2ZDfDbv+UbVsTaNmqJj17t2DL5ngefuDrEkF8D5vNycTx85k2dSWNmlSlZ+8W5OfZ+XDCfEY/N53cnDNp1dJSc3h61Pd8OGE+BQUOevZuwfU3NCHpeDqvvTSD
Ce/Ow+k8kz6qIN/O26//xmsvzSDpeDrX39CEnr1bUFDg4MMJ83n+qR+99i+EEEJcrCsmuKuqqvYHcq1atQgNDfWtUm769evHiBEjoOix12effbbEaEWHw8GYMWM4cOAAAG+++SadO3f2qlNRCgoK+O677wCoX78+t912G126dAHgr7/+0v7YFxVr6dKlqKpKt27diIqK8i0+K6fTyd9//w1AWFgYjz32WIkR30ajkRdffJGqVasCsHjxYrKzz4xMcDqdTJgwgby8PO3mQt26dYvtoVC/fv0YNGgQACtWrPDKT71q1So2bNgAQJ8+fXj++edLtENRFO68805uvfVWACZOnOg3h29WVhZvvPEGtWvX5tdfy3/G3/Xr1zNy5EhUVaVNmza88847JdqamZmpBT6bNm16XoHbc3E4HDz77LMcOHCAOnXqYLVafat4ycnJ0fJfV61alRdffLFEOxVFoWfPntqNkPnz55OSUvIPT1/Tp0/nf//7H2FhYee8kVSe7fD0kapVq1KrVi3f4nOaO3cuy5YtA+Cll15i0KBBJVLXxMXF8cYbb0DRdXfevHle5b///rs2CvbVV18t8fkVReHWW2/l3nvvBWDWrFleQfDyOh5lbUfx63dYWBgffvih359nERERvP3225hMhfkkZ86cWSJIvH37dlRVxWq1XtQNpgvVo1YMPWrF+K4WQoh/hbxcGxM+vov/fTuCF17ux/c/P0KvPi3JzMhj6eKSc3Pk5BTw6psDeGfC7bwz4XbuvrcL+Xl2Pv5gASknM7l3eFe+//kRXni5H5M+u4e33h1EXq6Nd978vUTA83KmqioBVhNTfniIj/5vGKPH9OenWaNo0TKOI4dT2LzB+3c+t1ulQcMqfDV1BKPH9OfDT4YSWyWclJOZfPzBAux2Jy+9egtfffcgz710M1999yAPjerOscRUJo7/QxvNeuTwKb6bsoJq1SP5/udHeH3sQJ576WY++r9h/N+X92OxGFn8506yswsHJaSczOSzTxZhMhl4f9LdTPrsHkaP6c/M2U/QqUsDUk/7H72bkZ7L7l3HmDh5qLbNz78/QavWtVi35iB//bkTAJfLzTf/+5ttWxPo07cVM2c/wegx/Xl97EB+/u1xWrSMY85vm5g7+8ygn99+2cif87fTqnUtfpn3FK+PHcjoMf2Z8dvj3NirOZs2HOarL/4ucaNYnFvFJji8EBfQkmJVL2CrcnU+fc1vnX+qweLiXMz5uoi0of62UPyurTh+++t5Ktz20ra3Il0xwd3iPJPwVBSDwcBbb71FmzZtoCio9umnn2odSy3KsztlyhS4hHl2PQ4ePMjq1asBaNu2LdWrV9dyBe/du5elS5f6bCHKW05ODqtWrQLghhtu8C0+p4KCAu0x95CQECIjI32rAFCpUiUtYHv06FGvCbwSExNZsWIFAAMGDKBDhw5aWXEGg0HrH6GhoaSnp2tl69admRhp+PDhpaYoMBqNDBkyBIDjx4+zY8cO3yqMGjWK1157TXt8vGHDhnz77bfnDIaej8zMTF566SVtlPz777/v95gdPXpUmyCtffv2Zf5Oer7rM2bMICwsjAkTJpwzkJ+bm0vz5s2pU6cOnTp18puiACA8PJzGjRtDUaD+XKOC169fz6hRowB48cUXueaawklmSlNe7cjOzmbfvn0ANGnSxO9xPxun08nvv/8ORUHV++67r9Tz0rZtWypXrkx4eLjXTaqCggLmz58PQOXKlenYsWOxrc4wGAwMGTIERVHIyMjQbqBQTsejPNpx+vRptm3bBkWjds8WlG3YsKGWSuH06dNekzaqqqpNpFezZs0SQWYhhPiv6X3zVbQpNlGYwaDnxp7NURSF5ONnfvfxiKsZxdVtvNPMbN92lG1bEmjctBqDb2+PwaDXyq65rhG33NrGb8Dzcnf7XZ2oVfvM/CWBQRa692wGwIkTZ0bbenTv2YzAIO+5JlYs28exxDSuu6EJN/YqPO54bnoOakvb
9nXZsO4QB/YXPoG1e9cxFBT69LuKyrHeN+NjKocSGmYlPT0Xu63wxubWLQkcS0wrcZ7NZiND7+tCpUohxfbgbeCQ9l6TyIWEBDD84esxGvUs/nMn+Xl2jiWmsmTRLqpVj+De4ddiNp+5ARwRGcTIUTdgNOqZ9/tmsrPzyczIY97sLVgsRh5+vDshIWd+h7YEmHjgweuoVCmExX/tJDmp5DEUQgghLsS/Prg7btw4FEXhhRde0Nb17dsXpWgm+4ULFxIYGEhwcLA28mvZsmUEBwdrdS5ktvvzFRkZyeTJk7WRf2PGjGHlypXgM5nRpcyz6zFz5kwyMjJQFIXBgwejKArdunUjIiIC/qGJ1YYNG4aiKAQGBmoBB1/5+flMnTqVpk2bohTN3h4eHs4jjzzCsWPHSE5OJi4uDuUss7qrqsq6desYMmQIERER2n5iYmIYOXKkFoTyx9PGrl27kpOTQ3JyMk8//TQxMTHafmrWrMlrr712zpnuPfk1K1euTNOmTX2Lz8lgMGi5Yu12e6nnq6CggFOnTgEQHR3tNWnb3r17OX688NHAPn36nDU1yZ133omqqiQlJXkFoz35Oq1WqzZCuDRxcXFaoLZ4UNiX0Whk3LhxbN68mSZNmvgWX5Qvv/xSe5z+oYceKnWUvGcyNYPBUC7vXTyn9jvvvEO7du18q5QQExPDxx9/zMGDB5k+fXqp5yU9PV0bRW0wGLRUMP4UD2736dOHUaNGaX80laa82lF8MrWWLVte8MSBntzUFAXcSwuqAtSuXZvk5GTS0tL44IMPtPXp6ens3Fk4sqZVq1ZUqVKl2FbeGjZsqJX//fff2mjX8jge5dGO9PR0IiIiCA8PJzAwsNR2UJS7Oi8vz3c1+Eym1qhRo3IZpS6EEJezKlXDS/xsDA6xYDYbyC9w4HJ537yMrRJGgLXw6QiPbVsTUFWVzl0alAheKopCq6sLn17ZuP6wV9nlrmq1cN9VRFUKBiiRK1dRFKpWK/ybw8PlcrNlczwA13Vrgl7v/buE2Wyk5VVxOBwudu0o/N2zT99W/LH4ee4c2pmc7AIOH0rhzwXbGf/OHEbe/xUnT3g/NbllY+H+W11dq8R5jowKpm59/0+WmM1GOnQqTPtUXFzNKOrUjeFowmnS03NJTsogKyufRo2rUim6ZKC4dp1oGjWpytGEVE6eyOTUqSxSTmZRv2Es1WuUvPFdOTaMlq3iSEvN4WhC4YAOcf4ufqxeebuAlhSregFblSvf74Y/fuv8Uw0WF+diztdFjID1t4Xqd23F8dtfz1Phtpe2vRVJ+8m6a9cuKlWqpAWxSlvq1aunTXjzX9e2bVtefvllKArAPf3002zbtk2bzOhS59mlaPTWjBkzoOgP+quvvhqKgsw33ngjlGFitYq0Z88eWrVqxbBhw7xyaWZkZPDpp5/SoEED5syZ47WNr9zcXO677z7at2/P9OnTvUagpqSk8Pnnn9OwYUPeeustv5MPFTd37lzq1q3LBx984PXYdUJCAm+88cY5Z7pfvnw5aWlp5wzwlMZisWj5L0+cOMH06dP9PnKwYMECLdDUqVMnrwCOJxd0cHDwRQWYyyI+Pr5Ee4ODg3nuuec4ffo0zz33XKmjgC/UwYMHef/99wGoW7cuTz/9dKkXeU/QOTw8nM2bN9OzZ88SNwEeeeQRv2klfBXPR3zPPfdw//33+1a5aKqqsmDBAu1GSK9evYiOPjNipjhVVZkwYQKLFy+mdu3aTJo0qdyO7fm0o3heV4PBwMiRI7WbMIqiYDKZ6NmzJ0uWLMHlOpOHzqP4aOqOHTueNZhZmuITusXGxp41wBwaGqqljjh58mSpN078OdfxKI92NGvWjO3bt5OWlsa3337rs5W3AwcOaIF130DwiRMnOHy4MLgQERHB6NGjqVevnvZki16vp127dsycOfOc18PzcSgjl+ELNjF8wSYOZVw5jyMLIa4c1ap7BxyLy8zI00aAeoSGWjGZvH8mnTpZ
eHN//bpDvDd2Toll4R/bUBSFY8fSyM8798SmlwOLxUh4hPfktMV5jomH2WwgOMT755/d5iQzIw9FUfhjzpYSx+29sXO04O/BopG7AFs3x3PXoMn06vYuw27/lDdf/YXZvxb+DNbpvH/Xczpd6PU6wsNLttVkMhAa6v9JsZCQAKKiCgPVxRn0eiwBJvJybWRn52vB5Fq1o/3+nmk2GwmPCMJmc5CXaycjPY/8fDsxMaEEBHjfJKAoqFC5KBdxVtb5/y4ihBBC+KMFdxs3bszIkSO9S/14+umnzzqyqrw98sgjJCUlaUFUiibTSUpKYsOGDVx77bUcPnyYAwcO0L59eyga/XXgwAGtjm8wojw99thjDBw4EIpG8V1//fVa/sRLmWfXY+PGjezZsweK0kF4HhE3GAzcfvvtUBSInjZtWong2z8lISGBfv36acete/fuLF++nKSkJJYvX0737t3Jy8tj5MiRHD161HdzKEqDMHDgQC0VRo0aNfjss8+Ij48vMbv8K6+8wsSJE0v9/GvXruXuu+/G4XAwatQoduzYQVJSEr/99ps2EdTp06d55plnvNIgeDidTv766y8oGmV+tgDP2dx9991069YNivKQvvjii1oQLTMzk3fffZf77rsPioL3vqk/PKOUw8PDiY6O1kZGN2/eXAvwBAUFce+995Y6otmT6/Nso4c9UlJStFGEx48fJzfXO8AzefJkxo0bR0hIydEOF0tVVSZPnqwFB892fSo+6eKpU6d4/PHHWbhwYYmbAJ6bCcVTrfhyOBy8+OKLbNiwgXr16jFmzJhyGZ3vcDjYtm0bw4YN44477kBVVerVq8ezzz7r9w8JgF9++YWxY8ei0+l47733qFmzpm+VC3Yh7fDkdQV44403+Pzzz72+pw6Hg4ULF9KtWzfuvPPOEvnJT5w4ofWtevXqaaPv+/TpQ1BQkBaI7Ny5M/PmzfMbIC7e9zyTrpXGsz/8pDIpzfkej4puR3EOh4MvvvgCu70weOB7rYmPjyctLQ2KJlx77733OHjwoJZGwu12s379egYOHMi1116rjdK/WHfMXsuX2w7z5bbD3DHbe4IbIYS40mzdHM+c3zaVWFYs21vq7w6i8Pe21Sv3lzhuc37bxIZ1h7zqblh3iMcf/pbjx9K5+ZbWjP/wTn6Z9xRLV7/CZ1/d73f07MXQ6ZUSgeKyMBj0GE1n0nWcj+IpHoQQQoiLoQV3FUVhxIgRfidc8mjTpo0WyLxUgoKCiI2NJSgoSFsXHh5ObGys9hh6TEwMlStXxmw2A2A2m6lcubJWx/MHdEUwGo2MHz9e+0Pe88f0pc6zS1FQ8dtvv0VVVUwmEz169PAqb9eunZa/8bfffiM+vvAO+bl40mCc7zJs2DDfXZRKLcpZ6pl4buzYscyfP58uXboQGxtLly5dmD9/PmPHjj3rL8vff/+9luvyuuuuY8uWLTz44IPExcX5nV3+lVde0dJo+LLZbAQFBfH333/z8ccf07RpU2JjY+nXrx9///23NjndqlWr/I6APnnyJBs3bsRqtWojpy9GaGgov/zyC6NGjUKn0zFu3DgqV66MoiiEhYXx4osv4nA4GDhwIMuXL/caIV48Z290dDQnTpzgmmuuYdiwYezYsUML8OTm5jJlyhQaN27M5MmTSxxjT85Qp9PJzJkzS5R7OJ1OZs2a5bu6wu3atUsb3dikSRMGDBjgW0WTlZWl9TOKchm/+eabWvB+1apV3H333VAUOHvkkUcYP3683888e/ZsvvrqK3Q6HRMmTCjz6Hyn08ltt92GyWSiZcuW2oRa/s5tcQkJCbz44ou43W5GjBhBv379fKtckItphyc/rMfdd9/NX3/9RVJSEvv37+ftt9/WAvrTp0/n1ltv9QrwJiUlQdENKLPZzNNPP0379u2ZN2+edoPA7XazatUq+vTpw5AhQ0oEiH0nEjubwMDAc6YY8bjQ41FR7fBn9uzZTJ06FYr6fv/+/b3Kt2/f7vXv
7t27M3v2bBITE0vc8FqzZg3XX389CQkJXttciE0nz9wkKf5aCCGuJJViCn+evTVuMCvWjyl1mfTZPSVSOvyXmcwGQsOsmM1GvpgyvMTxKr6MHtMfl8vN779uwu1Wee6lm3nupZtp37EelSqFeOU5Ls5g0ONyuf1OZqeqaom0Gx6ZGXmc9jPZmt3uJCM9F2ugmeDgAGIqFw54OHI4xe/vhnl5Nk4mZ6DX6zAa9YSFWwkIMHHyZCb5+SVHcbtcbo4mpELRaGchhBCiLLwSHlWvXp2nn366+CqNUpT39kIny/kviIuL49lnn9X+rdfrefDBB8tlJN+FOHDgAH/++ScUBTg9o0w9YmJitODX8ePHS8w2/0+Ij49n5syZAFx77bWMGjWqRDBer9fzxBNPaJN++Tp9+jSTJk2Cokehv/jiCy2/cHERERGMHz8eq9WK3W7n66+/9vvLGcDjjz/udzKk0NBQ7rzzTigK5BQf9emxY8cOjh49SuPGjald23sijguVlZVVYgRscTqdDkVRSgSVnE6ntt3x48fp2bMnGzdupE2bNsyfP5+kpCR27NjBqFGjMBqNuN1uRo0a5ZXHlKJ+5LkhMGnSJH7++ecSx0xVVWbNmsU333zjtb6iqarKl19+SUZG4SQUjz766FlH6WdnZxMZGUlgYCADBgzgyJEjvPzyy1rwvmPHjkydOpW5c+dquYPHjx+vjfb1OHjwIKNGjcLtdvPSSy9x8803e5VfjPz8fK8Jwjzmzp3Lhx9+6LcPOBwOnn32WQ4cOECbNm146623LiqlQXEX2o6CggIURaFSpUrExcWxfv16pk6dyg033EBsbCz16tXjpZdeYteuXdqNjsWLF/Pjjz9q+/CcP4AnnniCiRMnEhISwuTJk/2OvJ85cyYDBw684JGuHoqinPdxutDjcSEupB2+Fi1axAMPPIDb7Uan0zFx4sQSfT83N5fY2FhCQkL4/fffWbhwITfffDPVqlXTbnht376d2267DYp+fpztiYZzGd7izLWu+GshhLiSNG5SDYDNG49c9PXyv0iv19GocRVsNge7dxbOB3E2njQOFouR2nVL/m538MBJUnzSQXgmRFu6eFeJQG56ei779/mfg6WgwMHO7SWfXklMTOX4sTQaNIwlqlIwsVXCCAkJYM/u45xKKTn3RkL8aQ4dPEl0TAiVKhUu0TEh7N+bTOLRwiBucSknM9m94xghIQHEFqVnEEIIIS5WiRl6Bg4cSJs2bXxXc9NNN2k5QIW3hIQExo8fr/3b5XLx+uuvlxhdVtEWL16sjRzu37+/37ybffr0wWQqHEnw/fffn1cbu3fvzvDhw897ueaaa3x3UaotW7Zok37dfvvtXiO0iwsICGDIkCG+q6Fo4jDPxEE9evSgTp06vlU0DRo00Pr38uXLvfLpeiiKctZ0GsVz6PoG/gD++OMPVFWlQ4cOhIeXnIDifC1atIhmzZrx9ddf43a7ufvuu1m1ahVJSUmsW7eOESNGAPDzzz/ToEGDUoP1ycnJpKSk8Nprr7Fq1Sp69uxJbGwsTZs25eOPP+bvv//WcvW+//77Wg5Pim4IvPXWW+h0OtxuN0OGDGHYsGGsXr2a5ORkVq9ezbBhwxgyZAhRUVHExsYWe+eKVfzGQMOGDUuMXPRVv359Nm/eTE5ODrNmzfJ7AwCgd+/ePPXUU1B04+CHH37QyjIzMxk5ciTJycl069aNZ555BqUcRucbDAbGjBlTIpiZn5/PuHHj6Nevn9d31TPifcaMGYSFhTF58uRyufF2oe2wWCxMmTKFlJQU4uPj/f7sAKhWrRoffPCBdu356quvStwYcTqd7Nmzhy5durBv3z4eeeSRUkfe//XXX/z2229e21eECz0el8K8efPo168fGRkZ6HQ6vvrqK69JED3eeOMNkpKSyMzM1J6+8BUaGso777yjfW+nT59+Xvmm/fm/G1vx1+Br+GvwNfzfja18i4UQ4orQrEV1GjWuym+zNrBk
0S6vAK/N5mDcW7O5tv3rfPO/v722E3DtdY2JjArmy8+XsHO79/wtWVn5PP7Qt1zX8U0WLdyhjfQtKHCwfu0hr+N8LDGNDyfMLxFcb9WmFrVqR7N00S7+XrJbK3c6XXz52VIS4kuftOznH9eQEF84QTFF7fnfp0twOt3c1LcVJpOBatUjuf6GJhxLTOOb/y3DZjuTrz4tNYfPJi3C4XBxU9+rCA2zEhpm5aa+V1FQ4ODTj/7yyqtbkG/ny8+XcupUFtff0IRq1cv+O5wQQoj/thLB3cjISF544QWvPwStVisvv/yy32Dhf11+fj6PPvooBw4cQFEUAgMLk/ivWLGCF154ocSIyoqSmZnJ999/D0DVqlW1ydN8tWzZUguQrF+/njVr1vhWKWHUqFF88cUX571cyMRSnjy7AQEBtGp19oBAw4YNtRGVxe3atUs7ztWqVePEiRMkJyf7XbKzs6lUqRIU5fr0l2cyICDgooOy6enprFmzBkVR6Nu3r2/xeUtMTOShhx4iIyODsLAw/vrrL6ZOnUrHjh2JjY2lbdu2fP7556xYsYKwsDDy8vIYPny4V2C2uGuvvZZnnnnG72jyjh078uabb0JRINg3vcKAAQP45ptvtG2/++47OnXqRJUqVejUqRPfffcdDRs2ZM6cOdSvX99r24q0evVq7cbA4MGDiYnxPwvyhVIUhTvuuEML/q5atYqcnBxUVeXzzz9n8eLFhIWFMXbsWC0ncVkFBATQrVu3UoOZixcvZuLEiVr9DRs28M477wDw4osvlhpUvVAX2o4L0aZNGy0IuXv3bm2yr+KsVisTJ06kcuXKvkVERETw8ccfazcivvzyy4savVt8VPu5VOTxuJB2UBTQnzJlCn379iUvL08L7A4bNsxv4PZ81a1bVxu9e+LECa8JLS/UDTVjuKFm+XwPhRDi3yg01MpTz99EcEgAY0bP5K5Bn/DuW7/z2kszuKX3+8ydvZm4mpXo1efM03PTpq6kS9sxvD3mV6997dubRPdr3ua2vhNJLZYWoLT1l7tq1SN48tneFOQ7eOiBrxg+7AveGzuHF5/9iX49J7B50xFat6lFh0710Ot19OzdAp1O4YtPF3P/3Z/z3tg5jBo5hdtv/ZjIqCCaNK3mlVIhPDyQF1/phzXQzJjRM7n/7s95963fuWvQJ8ybvRmLpeTvwBT93peWlsvQIZ/y4jM/8tpLMxhw0wds3nSEfv1b06lL4e+2er2Oe4d3pUXLOObO3sxtfT/k7TG/8tpLMxh0y0ds25rAjb2ac8uAMynZbhlwNTf2as7mTUcYcNMHvPbSDN4e8ysDb/mIP+dvp0XLOO4d3hW9vsSf5CUs+nMnt948sdTlt1/OTPT89phf6dJ2DNOmeqegu9D1QgghLh9+f5L06tWLm266Sfv38OHDadu2rVcdUfjH9qRJk5g7dy4At912G4sWLdKCD1988QW///67z1YVY8OGDaxfvx6KHsOvXbs2ip98uFarlYULF0JR+7/99ttLFoD2xzORl6Io6HR+u6OmSpUq2gRxxRV/rPutt96iSpUqZ108oz3z8vL8jtwti8OHD7N7925q1KhBo0aNvMo2bdpEYGBgiXPiWQIDA9m0qXAG4D/++EML1L7wwgt+R+ZRFJj1pKRITk7WAvwGg0G70cA5RkUDdOvWTQtmFp8gi6JzM3ToUA4fPszDDz/s9fh3kyZN+Pbbb9m8eTO1a9fm1KnCUQ9Vq1b1ev/yVlBQoI2o9ZdfuqwqVaqkBRiPHDlCdnY2K1eu5JVXXgHgnXfeqfBrom8wc8aMGZw+fZrU1FQeffRRMjIy6NOnD6NGjSpTcO9cSmvHhbJYLNqo+ry8PC3Xrme/AF27dqVJkybav33VrVtXS5dy8OBB7bt/IekNiuejvhhnOx4V1Q6Hw8Hrr7/Ovffei9vtxmg08tNPP5U5sOtR/Jj7expBCCHEGQ0bVeGb70fSp18rUlKymDd7C0sW7SIoyMJDo7rz
f1/dT+VYeczen2uva8Q300bSoVN9Dh44yZzfNrFy2V6q14hk9Jj+vP3eYAKDCicH7dSlPhMnD6VW7WgO7D/BnN82kZaaw+gx/Xn/47to1aYWBQUOrxG5jZpU5ctvR3Dt9Y05dPAk82ZvQa/X8cbYgVx7feNiLTnDbDYw7v3buXVQO9auPsiSRbuIqRzK62MH8uRzvb1y/EZEBvH+pLt44pleWCxGFvyxjSWLdlGlajivjx3I6NduwRJwJteyJcDE6Ndu4fWxA6lSNZwli3ax4I9tWCxGnnimF+9PuouIyNJ/Py/O6XSRcjKz1CU3p8B3EyGEEP8hfqNpAQEBjBkzBqvVSmxsLI8++mi5/AF5pSke7ImNjWXs2LG0b99eG1Hndrt57rnnznvisoulqiq//fZbiceTzseCBQu00bP/ReUd2F6zZg15eXlcffXVZRpJunr1aij6Ll5//fW+xV66du2qTcy0cuVKcnNzsVgsXoHw4qkk/AkLC9OCv8ePH/c7orBatWp88sknnDx5ElVVUVWVnTt3MnToUAICAjh16hQnTpyAoicAKvKaceTIEdauXQtAhw4daNasmW+VMrFYLNoIb48vv/wSu71wQoyHHnqoRHBeURSqVKnC0aNHAVi2bBnBwcEoikLXrl0vapRprVq1aN26NRSN5j5+/DirV69mw4YNUJT71Wq1lmiHoijaZFue/qj43EC4EP7acTH8TSBWvG9GRUVhsRT+UedP8X6dnp6u5cONjo7WRvUXnzTPH4fDofXvGjVqnPWmR2lKOx4V0Y60tDSGDBnC66+/DkXHaNmyZQwcOLDcvmPnuj4IIcS/1Z1DO2uTcJXGU6dTlwa+RTRoWIW/lo/2mgDNs+5s+6wUHcLzo/vy17KXtInAZvz+BHfc3YnAwMIJnj1Ka6PnfWbOfpLIqOBzri9PkVHBzJz9JH8tH02DhqX/DBg9pn+pdTp1aVDic52tvkfNWpV4b+IdLF39inbspv70MD17t8BsPjO6VlEUWl1di6k/PazVmzbjUa3eiIe6sWL9GG7s2dxr/1WrRfDWu4NYtvY1bZtrr2/My2dpW2CQmcee6qm1adqMR7n+hiZ+B56YzUZuHdSOGb8/4dX+0urrdDquv6GJ1+eY8fsT3Dqondfnpej4+fZVz3E+13Ln0DMp5Tz7Kb7uYtYLIYS4fJT8CVSkVatWDB8+nMcff5y6dev6Fv/nJSQkcP/992O329HpdEyaNEk7Tvfffz8DBw6EopGcI0aMuKjAzvmKj4/X8k/WqFGDb775hhkzZpx18TxWnJGRoY1k/SdUq1Y4MYXdbic//0wuKn+SkpL8jnQrnmpg9uzZWuDxfJbymAzLo6CggPnz50NRnuILGcVXmoiIiHPmsQ0ODqZ69epQFKz2BPlbtGjhU/P8GAyGiwoaHThwQMv53K5dO9/icrV582btvbp27VpqYMxXTk4OJ0+exOE4kyfNn+zsbC33aHR09FkDjhXJarVqNwlcLle534w4X2drh8Ph4OTJk+d1jfOMCjUYDFrqkxo1ahAcfOF/vOr1eu07Vq1aNW2kdXJyMgUFpY9eSU9P19KxnC2oejalHY/ybsexY8fo0aMHv/zyCxSNsF2zZg0dOnTwrVqCy+UiJSXF68mG0hQfreu5Jl+o/WnZ3PvHBu79YwP7066cx4iFEEII8e9z4X+pVJQLaEmxqhewVbk6n8Fgfuv8Uw0WF+diztdF/P3vbwvF79qK47e/nqfCbS9teytSqcFdRVF48803eeyxx3yL/vOKz1IPMGLECPr166eVG41Gxo8fT7169aBo8p9PP/20TB3vbIrnHr355psZNmwYt91221mX5557TgvgTZ8+3e+M8JdC48aFj0g5nc5z5no8fvw4eXl5vqtp3ry59lk8ozn/CUlJSWzevJmIiAi/k8q1bt2a3NzcEgFmz5Kbm6uNBvRIS0sjOdn/7L4e2dnZJCYWTkxRPDDbtm1bLfh1ruNSPHBe
s2ZNLaXC6tWrqVevHiaTiWnTpvlsdYaqqvz6a2EuuYiIiHPmTy6rdevWQdHnLS1lha8HHniA4OBgYmNjWbx4sW+xl0OHDmlpA5o2bUpYWBgfffQRSUlJZ122bt2qjYJs3749Bw4cICkpiVmzZmnHdN26dTRs2JCIiAjGjRvn9b6+MjIytMkCw8PDiYmJoXv37iXe19/iyaMaEBDAggULSEpK4vDhwzRvXjjCpaztAJg/fz4mk4nKlSvzwgsvnPUal56ergURq1atSq1ataAosNmgQeEIld27d5eYaK24nJwcEhISwCd1RmhoKA0bNoSiHNypqSVnpfbYv3+/dr1r3bq19n0pj+NRHu3wSEhIoGfPnmzcWJhDb8CAASxfvvy8brZu27ZNa9eQIUPOeuPM6XRqI8GtVqvW/gt155x1TNkRz5Qd8dw5p/D7KYQQQgghhBCi4pUa3KVoRKBMouZNLTZLPUWTBL311lslRmnGxcXxzjvvaI/nvPLKK6xcWf5J6vPz8/nuu++gKNA1ZMiQEkECfzp06KDlDN27dy9Lly71rXJJdOzYUXtU+9dffy01COFwOLR8sr6aNWumBYd++eWXswaqU1NTadu2LREREbRv314LipaHnTt3cuLECZo1a6aNpL1Ynryi+fn52mjk0vz9999acL9du3ZaELFZs2ZasPjnn38u9bOqRWk9PIHz4sHSypUrk5OTg8PhYO7cuaWOHN29e7eWe/rGG2/UbmxUhOzsbLZs2QJATEwMNWvW9K3il2e0uqqqzJw5s9TP4nA4+Oqrr7TyHj16oCgKYWFhxMbGnnWJjo7WrgVms5nKlSsTGxvrlaYiJCSEU6dOkZ6ezqxZs84aANy4cSObN2+GopHYnrQFvu/rb/GkB1AUhaioKGJjY4mJidFGupe1HQANGjTQAqwLFy70O0Ghx59//qmlhOjQoYMWEA0PD9dyvG/atOms16ItW7Zok0C2bduWyMjC2aUtFos2gWFCQkKp+3A6ncycORNVVQkLC/OaeLI8jkd5tIOiCTLvv/9+7YbX/fffz/fff6/lxT6X4gHzFStWnDWP7qZNm7QnDpo1a3ZewWN/tpw8M0K4+GshhBBCiPJW+l9Gl9oFtKRY1QvYqlydT5zAb51/qsHi4lzM+TpLvKE0/rZQ/a6tOH7763kq3PbStrcinTW4+29SfNKdf1LxPLthYWFMnjxZCzD46tevHyNGjICitAMjRow4a/DjYmzdulULILRu3fq8c4+GhobSv/+ZHFlTpkwpNbBakWrWrKmNMFy4cCEffvghLpfLq46qqkybNo1Zs2Z5rfeIiYlh8ODBUBSofuaZZ/zmjHW5XHzxxRds2LCB9PR0atWqRew5Uh6cr+IjVzt16uT3EesL0bt3by3I8sEHH/DHH3/4VoGi8//SSy9BUX/0HEuKzvE999wDRZNPvfTSS36Py+LFi5k8eTIU3azo3r27VlajRg26du0KRYHz5cuXa2UeJ06c4KGHHuL06dOYTCYefvjhEjc7ytOpU6e0yeYaNmyoPd5/Ltdcc412TL/55htmzZpVImjuuXkzZcoUKEqv0adPH686ZVW3bl26desGRRMhfvHFFyX6PEXn7JFHHsHtdqPT6XjiiSfKNT1EebSjRo0aWmD24MGDvP76636vI+vXr2fkyJFaMPO5557z6iMDBw4kKioKVVV59tln2bNnj9f2FPWz0aNHY7fb/faz4t+Z1157rcQ+VFVl1qxZfPPNNwD07NnTa5RqeRwPyqEdqqoyYcIEbXT58OHD+b//+78LutEaHh6uXQvy8vJ49tlntTQmxSUkJHD33XeTl5eHTqfjpZdeIjQ01LfaeXnwqtp+XwshhBD/VueTI1gIIYS4HFw2wd3iE758//33HDhwgJSUFL9/fFeU4nl2AV5++WVt9Ks/BoOBt956izZt2kBR4PGNN94oNd/n8uXLGTFixHktX331FRRNqORpT//+/S/oD/ObbrpJC5ov
XbqUrVu3+lYBYNKkSSXe/1zLCy+8cF65HhVF4cknn9RGer700kv06tWLFStWkJyczOrVq7nlllu0WeJL8+STT2qBme+//57GjRvz+eefk5CQwLFjx5gzZw4dOnTQAqFRUVG8/PLL5RaETE1NZf369SiKoo0QLYvq1avz7rvvotPpyMvLo0+fPgwdOpRFixaRnJzMzp07eeyxx2jbtq02IvfFF1+kZcuWXvu5//77GT58OPg5Ljt37mTUqFH06NGDjIwMdDodb7/9ttfNCoPBwJNPPklYWBh2u52ePXvy2muvceDAARISEvjkk09o0KABK1asAODNN9+kc+eKnYwhNTWVzMxM8EkhcS7Fj6nb7WbIkCEMGzaM1atXk5yczIIFC+jRowfPP/88FPWRsWPHljlQ78toNPL6669rI149fX7BggXauR09ejSNGzfWUr+88847551+4nyVRzsMBgOvvPKK9v396quv6NixIzNmzODYsWNaP+3cubN2PXj55ZdL9NMmTZrw2WefodPpOHz4MC1atOCxxx5j586dJCQk8Pnnn3PVVVdp/eyhhx4q0c+Kn9/Dhw/Tvn17PvnkE62vjxgxgiFDhuB2u6lXrx7vvvuu1/e/PI4H5dCOXbt2aTdbFEUhOTmZRx55pMQ11nfxveaOHDlSuyYuXbqUq666SvvuHzhwgLFjx9K8eXOv9EK9evXStr9Qn3RvxZIh17JkyLV80r1i07IIIYQQQgghhDhDUX2Hrv1LHTlyhC5dunjN0h4cHMyKFSto0aIFOTk59OnTh2XLlnHttdcyd+7ccg3KOBwO7rzzTi0dw8CBA5k2bZrXZF6lWbduHddff7322Pt3333HXXfdBUWPxF5zzTV+c8mezdChQ3nvvffo2rUre/fuJSwsjBUrVtC0aVPfqqVyOp3cfffd/PTTTwA88sgjTJo0CUVRGDduHC+88ILvJuetRo0arF27VhsZO2zYMKZOnYrVamX58uUlcsvu2bOHAQMGsHfvXq/1HlFRUYwYMYKxY8cC8O6772pBOI+0tDRGjhypnaPS1KlTh5kzZ5YIMJ2rjR5z5szRHr32tGPlypVcd9111K9fn2XLlmmPaZeFqqrMmzePO++8k6ysLN9ijdFo5MMPP+Shhx7y+1hCbm4ujz76qDYa1R+r1crPP/+sjcL0NW/ePAYNGlRqP9XpdHzwwQc8+uij6PV63+JSFe///s6pP/6O//lSVZXvvvuOBx54oNSbLBSNCP7ll19o1KiRb9FZJScn0759e44ePXrO69C5+jxF53bChAk88sgjF3RcuYD+XB7tON99fPnll9x9991+++n5npunn36ad955x++1V1VVPvnkEx5//PFSbwZFRUUxf/58rr76at8iuIDPcrbjUZZ2jBkzhtdff91r3fnwveZSNNr5rrvuOmeO6TfffJPnn3/e7zEVQoiKYrNnsWHbRN/Vogxio9twOn03DkfJJ7WEuFiRYQ1IzdgHgNkUht2ehYqbQGs0Op2F/PwUrAExZOUUzotwPkKDa2G3Z2JzZON2n/m9z2AIwGQKwaAz4nDkYnNm43YVpkoLC6lJRlY8ocE1cLmdOB152Bw5qKoTvc5MgCWcnLwThIbUJDM7HlQIC6kNqLicNlRFJSf3BIqiQ1VdKCioqARaK5Obd4LQ4KLtgLDgmtgcWdhsmbhVFzqdEWtAFDm5yZhNoZhNYSgKuFwF2Jw5OOy5hAbHkW/LxO0qwOkqIDy0NumZhwkNqYnqVnG5bOTmnyA8pDYOdwF2ezZ2ezZmUzA2ezYWUyiKzohOb8SgN5OZFY/FHEaBLQOj0YpBZ8HuzMXlsmHQWzGbg9HrLGTlxHtNChUaXJMCWzo2eyYGvRWnK4+Q4DgURYfbZcfhyqegIK3oOBT+nmq1RGM2B+FyOUFxYbflUGDPJDQ4DlV14XQ5cDpzsTtyQVFBBZMxCLsjB73eiF4XgKKo2OxnJtUNCapBVs5RACymcArs6QQH
ViE7t3BeE0VR0OmMhARVw+Wyk5VzjABLFKrqosCWjk4xYDaHYrdn4nQ7S0x7FRJUg8ycoyiAUW/B4SogLKQmNnsW+QVpmE1h6PUG8vJPExpSC1QVl9tBTu5x9HozLpdNa09ocByZ2QmEhtTC4cgtXJx5gIrBYMXpPPM3cEhQdbJyzqQ7DLLGkpOXjMkYjN1R+Pn1ejN6vRlQsNsz8Tz8HxpSm6zsIwRaq+ByFYCqoqou7M5s3G43QdYq5OQlYbVEkZd/CkWnR68Ycbpt6HSmwr6uN6Og4HDmEWiNIb8gDbfboX2GkOAauFwOcvOSCQmKAwWysr2/m+EhdUjPOkSgNRajwYLTZUN1u3C5bRTYMggKjMXpyKfAnkFwUFWyc46j15kwGgsHVRXY0gkOrIqKistVQH5BGkZjIKrqxmwKJTfvBIDWJkXRERIUR2Z24YTlZxMSVB2HMw+Xy47DmYOqgjWgEnn5Kb5VCQuphcOZh9OZj6q6sTvOTPDt+e4UbnsKoyEIsykYRacnO+cYoUE1cOFEdbvJzTuBxRxKgS0TszkUszGEiLAGVIvt5PV+/qza+EaJp4H9iQxvTGp66anyiqtfuz/RkYXz5JzN6bTd7D00g5Cg6jRvdN/lM3K3Vq1a/Pnnn3Tp0kVbl52dTVLRpEcVyfOotidoWK9ePcaPH3/efwi3bduW1157Tfv3qFGjtJyhZbF06VItANGxY8cLzpVoMBgYNmyYFmj57bffiI8v/IF2qTVq1IiNGzcyadIkmjRpoq2Pjo7mqaeeYs+ePbRv315b7+/YR0REMH36dNauXcvgwYO9Htc3Go20a9eOadOmsWPHjhKB3bJatWoVTqfTKw9oWSmKQp8+fUhMTGTChAk0a9ZMy+FMUZB69OjRJCQk8PDDD/sNmAEEBgby9ddfs23bNgYPHuw10jUuLo5XX32Vo0ePlhrYpWiU98GDB3nqqaeIjo7W1kdHR/Pwww9z8OBBHn/8cb+BrvJWPPBXrVo1r7JzURSFoUOHkpyczLhx46hTp45WptPpaNu2LdOmTWPz5s0XHNi9UI0aNWLz5s1MmzaNdu3aefXpuLg47dw+9thjFXpcy6MdjRo1Yvv27cyePZvOnTuXuo+hQ4eW2k+Ln5tXX32VuLg4rSwwMJDBgwezbdu2s157FUXh0UcfZffu3dxzzz0l+vq4ceM4dOhQiYBqceVxPMrSjiNHzv1Lz/mqXLkyCxcuZNWqVdx0001e7YiOjubBBx9k7969vPzyy6UeUyGEEEIIIYQQ/26XzchdIYqP2Jw9ezY333yzbxUhhBBCCHEZkJG75U9G7oqKICN3ZeSujNyVkbsyctc/GbkrRJFt27ZRpUoVGjZseNZ0CqqqsnLlSihKx1GjRg3fKkIIIf4he1KzGDp3PUPnrmdPaulpZIQQQgghhBBClC8J7op/VNWqVQkPD2ffvn2MHz+e1NRU3yoArFmzhi+++AKAVq1aeT1OL4QQ4p9199z1fLcrge92JXD33PW+xUIIIYQQQgghKogEd8U/KjIykgEDBgCwYcMGWrVqpc3onpyczPr163nsscfo2rUrGRkZ6HQ6Ro8eXeokVUIIIS69bSmZfl8LIYQQQgghhKhYV3xwd9y4cSiKUuZl2LBhvrsW5UBRFF544QXuueceAI4ePcrIkSOpWbMmVapUoV27dkyaNAmHw0FISAhz586le/fuvrsRQgjxDxrZsrbf10IIIYQQ5c03/+o/5wJaUqzqBWxVrs4nN6jfOv9Ug8XFuZjzVcqk12fjbwvF79qK47e/nqfCbS9teyvSFR/cFf9+gYGBfP3116xdu5bBgwcTHh6ulel0Opo1a8aECRNITEykV69eXtsKIYT4503qfhXL77yO5Xdex6TuV/kWCyGEEEIIIYSoIFd8cPf5559HVdUyL99++63vrkU5UhSFdu3a8dNPP5GWlqYdd5fLxfbt23n66acJCQnx3UwIIcS/RJdqUXSpFuW7Wggh
hBCiXF38WL3ydgEtKVb1ArYqV8p5jM70W+efarC4OBdzvi5iBKy/LVS/ayuO3/56ngq3vbTtrUhXfHBXCCGEEEIIIYQQ/3VXTiBHiCudqqrynb0AEtwVQgghhBBCXFpul/ZS/nQrHyoqqupGkQMqSnFRfaP4iD7Vrb10qy4UtbDP4XScqVOM79upxfJbKmrJcgCX6sKlunCXMiDP5XKC24niky1TdTtRfUbxFd+/S3Vp/y78f+F/XeqZa5EbBXfRPkocq6J6qurWjonqLrZPl7PkNsW4KNzehQoup9Y4tWi/btWF6nYVlrmL9q+6z3xG1V3s+KvgLjz+iur9QVXlTN5TzyhKbSuvNp7ZSDuXLmfh4jsa0u0qkZnU7TluatFnKNZPfM+zu9ix1t6/6Eld1GI/D1Q3btza+xdez86MBfWcGwC326n9W6XwuudG1Y6Rqrq8fs4AuIqOhNtzHH3KofCcFvL9xOBSFO18aYrqe9Yr6pl2l9xDIdXtQik6S0rx41Wsj6mFLwr7m6+ida5ix111ObX/e/qaU3WhFPUlf1yqC9xnjoWKglr0WdSiJrndhftVANVd+Lm0+i6n1naX6kZR3bg8bVMUbVsoPM2lt+RMmUst7KPF+6DnGPhu70ZFdbtwFx3L4uWedmjbqm5Ut1Nrv6evF+97xev56xv+qMX6p8v3e1MKt1J4nM/G5fZ/TS3Oc+3yfBcVtTAcLoQQQghxUXadzmLsmj0AvNShEU2iJI2OEOLsck7vJ2v3nwS27MNJeypb8/bSI6yTb7Ur3o6Uv3GrZ/4AvhAOhwO9Xo9Op8OturgqOQSzQ8XeqClq8MVdh4P+OImqP78/UMXlJzM0iYyIZBSd3rfIL9XtYndwD04pVakfvIyCjOM0TskujC0adGA0odjtqIEhGJwqR+PiOOVIw6E6sbnt3BDWAYNy5r2SbKfYkrObyrpA9NmZxGbb2FrJTOeELAoMCqtrhND1eD6JVaJJUHJQVJXKlsq0DGrM36nLScdG/QwXNbJsKIoeVBWlKMC6umY49QLrkG07RbIzkxzVznUn3RgdTlAUEgMVkq16nEYDHRLS2Rdhxh1VnSjFygFnClFYSHZnoaLSNFMlOttWGAgqCpqtqhWOW9HTObGwjieAmmnUsTs2lFppuVTOcYDqYn0lA/UyXUSpFlz2XFRFT1KIiSNhATRNzSckz4EesCmwpmogRp2RTvFpuHGDy832yoFkWwPokpCBuXZrbIc2srxGEMEuPVclpoGioOo9n18lW+9iU+UAbAY9nTPNWPPyUZxu3KqdVZXNXJ3qwuKJlqtuHEYjq8Ld5BpApyj0SHKiuFUUvR5s+Zyy6DhYrTL1TucQlWND0enA7SJPp7C+qhWXw8b1yU50ig6304ZiMLIi1oxTb0BVFKpZYjlVkMzVx7IxOd2cDjSwJ9xEm5MFhEY3wli5Ho78dOx7VrA+LpTQrDwa5epZG6GSazKgd7m59lgebr0eXHZSrSZ2hxuIC63PofwEKhsjycpLoeWxDMyqjr3hBlKDArg61YVZZ0Rnd6CisjrGSKbOxdXpEJVr52iElZNhQTQ312J13m6uz49Al36SzWEu8gKs1E23E53vQHU60SkK+ThZWz2EXLeNaFMUFsVMhiMVs97KdWHtAMhN3oEuNRlXfiZbogOwWSy0T8pDUd2gKNgMBpZXNlDVFE2+I4fTaj61s1XqZhQAKorDTl6glU1Vw4jKzKZJjh6XLZcsq4VtkQaCXAZaJWXiBnR6HarLzcrqQZiNAcTpI9lnP071gOrEpWZhyMnGFRTMMcdp4iOCaJmYRpDegtEUhLleR+17uD1jCycpoEVCCoF6M6rTiRuV9AAD2ypZMKsKLlSMejNNE04Sqlhw2/JQjGZUlxO92YrTkY+CwrEgPUcDdXQ4YQe3i2OhJhJCzEQEVibNdhoHKnoUAlQ9RlMQua48eoR31trikePKY1nGBqJ0AVjTThPiNnAwxED74zksrxFC
rlqAVWdBpzOgKApO1UWwaiA8O5/AfBsnA/XUzXazNsZCiDGI2DwX1U6msaJ6EEaXSttjWaBT2BJlJNWip1amnTrpNnA6sBn0rIsLx2Cw0DE+FTdwPNJASGR9Glfu5ttUTZYtjTnJs+kY0oqtubuJMISR7cxFVVQiDCE0tdbX6i47vRK3YiTcEEKqOwO16GZDj/Azv/cUuG0szVxPaH4m9SJbU7dSB63Mny0pyzlUcJwgt5tAp11G7gohhBCibO6eu44fdh/lh91HuXvuOt9iIYQoISAoFsXpJDSkJhHBcbiB0JCa/7nFYgoho+DkRS25rjSy7KfIKDhJlu00xoBwdBnpBAVWKfE+57voVSOGEwWyXKGL0RSMI/cU9uwT57U4ck8RYo3C7jYTaIlEsYSAy4Vqy8WkCyAovC7kZRMaWgujwUpwQDRGgxWzMZAMZxahIXFe/SsiuAbZ7lwiQ+tiCAgjAANZOheqIx9zQT521Y1e1RMRGIfJFESGWkCotTKhITUJC6yOze3AYgnHYg7DaAxCrzOgOmwoTgdOnY7o4FpEhTYg3Z2HS3VjDamKJTAKi7US+vw89Hoz6djQOZ0YXG4CLZFUiWyOE5XI0DrkuPLJc9uwBsZgsoShV0yo9gJw2tHrA8hwZmFQdQQERBEQFIPFWgmrzkKqI53AoFhMxlDIywZFT0BQZVyZKQQExRIQGE2gPoQMZxbBgbFYA6MxGoOx2OzYVCf5bhs6t4pJH4jJHEpYYBUynFnoXC5CwxugOJ24FB1ukxXVYQOXC6M+gIDgWNSCXEwFNrJ1haMmQ0JrYraEo3M7MSpG7AYFiyUCizUaS2A0JlMoxsw0cgwqKiou1Y1Jb8UaUpWAoFhw2LAoZlId6QQFVsZoDMJoDMVgsBLgdJPtLiBPr6J3u3HnZ4HTgeJSsZjDyVELsKlOAs0RpLnzMaoKisuJMS+PTJ0Ti85CUHgtQkNqEhXdEtWWS0hALDa9Dld2KgHGYHJUG7l6UJ02DDoTisuFMS+XTL2bKsF1yHXlExIQzSmdHZNbh94QgMHhJF3nwBpcFWtwDRS3igEdgZZIVCAgIBKTJYxASxQ6vZmqUS3Ic+cTGd0MV/ZpwgMqk+rOIdAajckchtEQAC4X5vx8st0FuHGj15uJCKhMhjsfszFI69ORlZrjyk5Fp7NgdDg5rbOjOOzo0KPXBxDgVLHhJjSwCtGh9bDhRGcOwmiwotryUXR6Atx6MpxZ2A16XNmpGExBGHKyyFQcEBBc2P9UFTU/F8WtYrKEYLVEUiWqJZmqjXBrVSIiG2PQW4iIbEyAzkqGMwuLKRhrSDUMhgCv72FUWH1SHekE6C1Yg6thsUZhtoQSoAvA7naQ7s7DZAwi1ZGB0WjFnZuBooIOBYPegis7FbM5DIMxEENBPmlGFcXpQKcqGAoKyNA5iLTEkuHOx6Azk+XOJyyoGsHmCNyKUuLnTmhITWLC6pPtziU0qDp6czCB+mDyjDpUWy6KwYwTNzZcOFQXiqIn31VAsDUWa1AMFsWEw2TC7IJcdz56vZnwsHrgcKDozaQpdnQuF2ZzGOFB1ShQHRhMgWAvQHW7MNvt2BQVh+pG51Ix6wMINEdh1Af4/uriJcQcQa4rn6DAWFyKQqA5AoMhAIM+gNCAwuuWZwkPro5ZH4TVFEGAMQSTMQA7Lp9j0IACtx2TwYrFEOT7diWEBVYj1ZGO1RyOS3VIcFcIIYQQZbPjVKbf10IIcXbyAKEQlwN/39TSJjK6LB8M/gfb7BkB7M8/16rzoPx7Q0ml9c3L23l8pn/yc1+C91YUBd0F9DtPk86naWf7Hl6of6r/nf+REUIIIYTw4+Gr6vp9LYQQQgghhBCiYklwVwghhBBl8tENLVl11/Wsuut6PrqhpW+xEEIIIYQQQogKIsFdIYQQQpRZx6qRdKwa6btaCCGEEEIIIUQFkuCuEEIIIYQQQgghLpp6CTPEXsr3uiBnS7V5trLydpYcwmcp8k+5
xG3/p13wAboyXA6n+HLI5/0PpdsFCe4KIYQQQgghhCgvw74fi/LYtczZuRqAHFs+XT9+nMBnerApcZ9v9X+VtLwsnv71E05mp2vrNiXuI/CZHgz7fqxX3X9aVkEuz8/+jPDnb0J57FpMT3Zj4Nevceh0km9VIYQQVzgJ7gohhBCiTLafymTw72sY/Psatp/K9C0WQggh/vVy7QUM+noMM7cuw626fYv/VRLSTnL1+BG8t+hHIgNDuK99bzrWasrMrX9z9fjhrE/Y47tJhSvP2ebP5VK+1wU528DCs5WVt7MMHzxLkX/qJW77P+2CD9CV4XI4xcplcG7+ycHFEtwVQgghRJkMnbuOn/ce4+e9xxg6d51vsRBCCPGvp6oqTrfLd/W/jtPt4s2F33Lg1DHevOl+9r08ja/ueJ6/H/uImfe9QVZBHo/N/JjM/FzfTYUQQlyhJLgrhBBCiDLZnZrt97UQQgghyldWQS7HMk7RJLYW97bvjV535k/6m5p0oGejtmxK3MeO5MNe2wkhhLhySXBXCCGEEGXyyFV1/L4WQghxZfDkd636yq0oj12L8ti1BD3Tg+E/judYxinf6hUm32Fj0vJftHaYnuzGE79MYl/KUeJeG0TXjx8nx5Zfon7dN27X2m16sht9Pn+B3SfitXpzdq4m+NmeLDu4laPpJ6ny8gDiXhtEclaqVgdg94l4bvrsefSPX4fy2LU0fOtuZm1d5nein+SsVJ7+9ROvnLi3T3mjxPFKzkol7rVBDP3ubaZvXkLoc71RHruW7p885ZX71yPCGsKCh8az88UpVA2N8iozG4yEWYNwul2k58nNViGE+K+Q4K4QQgghymRit5asG9qNdUO7MbFbS99iIYQQl7Hi+V2jg8IY3rEPQ9v2IDQgiC/XzKX7J0+VCFhWhMz8XAZ8+TKPzfwIh8vJ0LY96FirKR/9PZObPnue07neOd8z83O5+fMXeWzmRwDc174397XvTe3IKszbtYZOEx9hy7EDAFQPj+aedr2ICgzFpDdyx9U3cHvrbgQYzdr+Zm1dRstx97P1+EHuatOda+u2ZF/KUW77+lU+XzVbqwew9fhB2r0/kg+W/uyVE/enzYtp9s49fnPiztq2nLu/e5urqtVj0FXXUTuyCpGBIb7VzupYxinWJ+whwhpC3UpVfYuFEEJcoSS4K4QQQogyaxsbQdvYCN/VQgghLmOqqjJ+8Y8cOHWMcX1Hsvm5L/liyLN8e9dLxI+Zzj3terH35FGWHdzqu2m5+2zV7yzYs55u9Vtz4JUf+Paul/j7sY9Y9eQnpOZmkWcv8Kr/3YaFLN6/ieEd+7Bn9Hd8dcfzfHXH8+wZPZWxNw8nIz+HmVv/BqBl1bpMuu1xmsTWonJIBBNueZh3+z5IWECQtr9cewHP33AH8WOma+/90z2voSgKU9Yt0HLc5tjyeWLWJBLTU3it1z3aexfPiXvfD+NI8RmVm2cv4PXe97J01IdMv3cMnw95BoNO71WnNC63m63HD9L3ixc5eOo497TrScPoGr7VhBBCXKEkuCuEEEIIIYQQooTUvCx2JB2mQXQN7mzT3Wu2cqPeQMOYwgBiRY/cTc/LZubWvwkLCOLDW0cRGhColXWs1ZQXut/pVT/PXsCa+F1UDongwU59MeoNWpmiKDSNrQ0X2O6GMTV49JoBXvvq1qA1jWLiSM5KJc9RGFxeeXgHyw9to11cY57sOsir/oAW1/BQ537sSj7CkgObtfUUpVu4pXmXC54Rfv7udRieuI6rikYVP3fD7YzrN/KC9yOEEOLyJcFdIYQQQgghhBAlRAWGsuzxj9n78ndEB4VxPPM0f+7dwHuLfuSGyU/xyryvfDepEEfTT7LvZCItqtalZkRl32Kur9/KK4WC1WRh2tBXSH7rV1pWrUdKdjp/H9jKpyt+45b/jeb2Ka97bX8+qoZWIsgc4LXOYjBRKSgMu8tBvsMGwIpD21BVlb7NOnkFoSkKLF9fvxUAi/Zt
EhERUUzGLt8Mr6LAqygYu3yzMZmIiIiIiIiqCYO7REREFBOH2RT1NREREREREVUvBneJiIgoJpN6no/GSXFonBSHST3PNyYTERERERFRNWFwl4iIiGJyc4t62HXfVdh131W4uUU9YzIRERERERFVEwZ3iYiIqEoJggBRFLlx48aNWyU3QfUB0ErvB0rvE4RKXW8VTYai+Urtr8pNhgyTYDqxT5BK5eF2Zm+KoMAsmCGKIqyiFbIgA6ofULyBMaf6IEgmCMExqECBAj8EKdDXEERoshvwe/QyBQEQRQFmwQxZkKFpGjTZDc3vhhD8vQGiCEGVAcULUQu8nx8yVEHRy4FoDoz5YBhDEAQIohlQfBBMNoiaH6IgwCpa9XMRVD+geAAhkN8m2KBAgaD5ISgeCKoMUfNDECUIAgAIEFQ/BIsTEE3wC4BDckBCsH6iST9PqDKg+ABBghB8X78mB8oPvh9EKZBH8cIPPxyiAwIEaNCC5xZoVyg+aFAhCiIEQYQgSRBUOdA+ABBsd1mTYRWtAAC/4IcgmQPnEizfLJihQQuUq8qAICHQARJkTYZfkwN9ErzWQPFCsMbBBy988AauP4IAQTJDDtYl1P42wQYRJ8aKYLIF3huAH36IoggZPlgFKxRBgSCaIIoCREEEVD8E1QtoKgTFF+gHTYGsyQAAUZUByRwYX0Lk9U5QvYDigVfzwCSYoQTfK9CugU3zeyGo3kDfBdveLJrh03zwww9BFCCIoj62RVGEoikQBRGiIACiSe8nUVMCY0oyA6oPotkePH8ZAoLjT/FBU7yBfKoMUTIF+s1sBxQPFCiQguNCFAPvcaK/AnUM9JEfECUIggjRbIemeANjT/EF2jr8vYLjTZBM+niD4oGm+ABBgEWwBM4vNE4DAzpw3sFN1Pz68aF2gGQJ9IEoQlMVKJIJdsEarK8GQTRDC9ZfNFkCx0GEoAXOHaocfD9A87sCbSCagmMMUOGHrLoj+jR8gwAIqg+y5gt+FrzBMR+8Dmh+WAUrBACa3w1N8UFTAnmtghWyGuhjKJ7AtQcaRNUHs2gOlC2ZA30jBj4/DskJH3xQoMAsWCCYbYHrgSAGxrMUXGKuEj9fwzdN8UFWfXCIjsD1K/hzUCjjZzZUGfAH+l0Lu25WdhM0TdMCZ05ERERERER/F1X2QFP9kKxx+j5ZlQN/pBrk5OQgKSkJdrvdmFQmv+KDKfiHfHXwql5YRSs0aNA0FaIgGbPQWUKDBgECFC0QpFJlD0SzDZoiQ5DM8PndsJgCY8+n+mARLXoeBMdy6DUAKIoPkmSBW3HBLjkC7+H3QjBZofq9EE2BgGXoOFV2QzTb9bI1VYEgSvr/AcCvyjCJZqiyC6LZEQj4BMe37HfDHDw/TfZAMNvgV7wwSYH3iXiv4Psbz1lT/RDEQKBH1RRomgJJtEDxlgACIFmc+nt6ZResZgd8qg+apsEqGesVqA8AeBQPbFKonQL7Q+0KALLfA7MpMl1VfBAlCzyKGzbJDln1QYCg11/TNEgWJwDApbjgkE60R3jZbsUNu2SH6nMH6hO81miqHz5NCZ63D6Ip0I6h/Jqm6cFCAGFtFjg/v+aHSTDBq3hhlaz6dSvUxwjrB1WRIQbPx6t6YBVPjCtjH4SE7w/VT/GVAFogdg0IEC0OPV9ovKmaCkXzwyxa9P4MjekQWfHCHBwXqs8F0eI4cT7B/6uyG4IgQQi2i+orgYbAGAi1l+r3QDTZoPk9gcB3UOj9Qm0UKhNhnwH9sxB6v7A+UH0uaAj0b+i8/KoPJtECjycfZk2AZE/U3y9Un4jxp9crOCaCn6cQTVUAaPp4h6ZCBQC/NzA+VQVQFQgmS0R7AYDiK9HHnt9dAJM9EdCUQB2CY15WPDAHx3w0miLDCwU2KbL9Iq4pmgoEzwOaAsVkgVk0w624ASAwrkOf6eAYU1QFkihFfAYQ9rNK9XuhSSaIqhLxWTFeKyorVI5HccEWvN4dPXoUNpsN8fHxxux6
ncOvE5XF4C4RERHFTFYDv06YxRO/+BMRUfU5leAuERERnX7lBXerApdlICIiophM/n0X7K9+DfurX2Py77uMyURERERERFRNGNwlIiKimDy9YjMUTYOiaXh6xWZjMhEREREREVUTBneJiIgoJnGW4PpchtdERERE13XHCQAAaVlJREFURERUvRjcJSIiopj8p+eFaJYSj2Yp8fhPzwuNyURERERERFRN+EA1IiIiIiKiswwfqEZERHR24APViIiIiIiIiIiIiKgUztwlIiIiIqJKUxQfDh/bAEmyRuwvLslBnDMdPl8RFFWG3ZYSkU5VIz8/Hw6HAxaLxZhEREREVSgxPhM2a7Jxd4VV98xdBneJiIgoZm6/AgCwmyRjEhGdw37fNBluz/GIfXGOdBS7ciCKJqiqAoB/bhAREdHZq3O7JyBJp/5lanUHd7ksAxEREcXkjXU74HjtGzhe+wZvrNthTCaic5gglP2FjiDwTw0iIiKi6sbfuIiIiCgmz6zYEvU1ERERERERVS8Gd4mIiCgmCVZz1NdERERERERUvRjcJSIiopi8ecWFaJkaj5ap8XjziguNyURERERERFRNGNwlIiKimFzfLANbh/TB1iF9cH2zDGMyERERERERVRNB0zQ+vpaIiIiIiCpt/ea34XIfjdgX50hHsSsHkmSBosgA+OcGxW7G9JV4d8oi4+4y1aqdiPc+GgIAuG/wB0ivk4yJr98Gu+PUn3b+d8s9VoT7Bn8AAHjvoyFITauep67Hwu3y4bF/f47tWw9iytR70LxFHSCs/+5/8ArcftclxsNKCeV/+bVb0bVbc2PySWmahl9X7cSxo0Xod/1FxuRzVkU+J/HxdnTo1Bh3D7kUDRrWMCZXqxfGfYufl2yNGBtEZ4PO7Z6AJJ36z4+jR4/CZrMhPr56rtucuUtERERERERnNGecDTVrJUZs8fF2AIAoCqhRIyEirVbtRIgS/9z9p1r96y6M+ffnKCpyG5P+EWqnJ+Ha6y4qtV1+RSto0LD4p814YMiH2LbloPFQIjoLnVU/7dxuNyZOnIhFi8r+JmrHjh3o27cvLBYLBEGAIAho0qQJtm7dasxaKZqmYeLEiXqZzZo1Q1ZWljFbVLIsY9iwYfqxEydORGjC9O+//w6n0wlBEDBo0CDjoTH59ddfYbVa9fedOXOmMUuZBg0apB93si0lJQWdOnXChx9+iMLCQmNRQFh5TqcTv//+e0RaeBtUZJMkCQ0bNsT999+PrVu36m0ZjaZpWLJkCcaPH29MKtP+/fvxzDPPoG3btpAkKaKeffr0wXfffQdZlo2HnTVyc3Pxn//8B506dYr4nMTFxaFTp0747LPPUFJSYjxMl5OTg8zMTAhRxmx1jueqVlhYiLFjx+KPP/4wJpUyY8YMCIKAa665Bh6Px5h8UmvXroXD4Sg1lqNtmZmZyMnJMRZRpoKCAlxxxRUQBAE9evRAcXGxMUuEwsJCTJw4EU2aNIno+4EDB2L16tXlfp5C3G43pk+fHvEZkSQJbdu2xfTp0+F2l/9LdFW1h6IomD9/Pi655JKI87jkkkswf/58KIpiPOSkNE3DSy+9BEEQMGHCBGNyhVW2X0J16datm/65lCQJHTt2xOeff15mm1b2+hnaol2LETyPU1Xs86PY5zfuJiIiqhLX3XAxvv7ukYjtqXHXAQDaXpCJGV8+GJH21tTBSE52Gouhs8Ttd12CFWvGndKsXQBQFNW46x/lggsz8diT15baxr84AHO+H4W+/dqhuMiDD95bCq/39P1t+9S46/HT8qc4a5eoip01wd1t27ahXbt2GDNmTJl/5G7YsAEdO3bE/PnzI4JvVqsVNWvWjMhbWYIgYMSIEejbty8AYOfOnRg3btxJg3yapmHSpEl4//33AQADBgzAI488AkEQjFmrlKZpmDFjBnw+n77vk08+KbPtYpGXl4fVq1djyJAhqFevHhYsWGDMUqVUVcXevXvx3nvvoVWrVhg9
enTUfigpKcHgwYPRs2dP7Nmzx5hcyo4dO3DllVeifv36ePbZZ7Fp0yao6olfCvLy8rBw4UL069cPLVu2xJIlSyoUCDtTHDp0CPfccw/S0tLw8MMPY/Xq1RHtVlJSgtWrV+POO+9EnTp18MUXX8QU6DmTLVmyBI0bN8akSZPg95cfjNI0Tf9CqXv37rDZbMYsJ7Vjx45q+expmoZXX30VixcvNiZFFar3mDFjsHv3bn1/SUkJZs2ahU6dOmHw4MHlBvcPHDiA7t27Y9CgQRGfEVVVsWnTJgwaNAjt2rXDtm3bjIfqqqI9CgoKMHDgQPTt2xerVq2KOI9Vq1ahb9++6Nu3LwoKCoyHluubb77B2LFjjbsrpbL9cujQIfTu3Rt9+/bFypUr9c+lqqpYs2YNbr/99pO2aVUIncepeH3tDsRP+hbxk77F62t3GJOJiIiI6AxhtZox8I7OSEiwY/vWgziUk2/MQkRnmbMmuDt37lxs377duDvCxx9/jPz8wIXp5ptvxo4dO5CdnY25c+ciKSnJmL3S7HY7Jk+ejEaNGgHBYOmsWbOM2SKsXbsWL730EgCgadOmeOWVV2A2m43ZqtyRI0f0wEKDBg0AAEuXLsXGjRsNOU+uV69eGDp0aJlbhw4dIIqBoVRYWIjbbrsNa9asMRZTITfddBOys7PL3Pbv34+5c+eiV69e+jGvvfYa3nzzzYhyAGD79u2YPXu2cXdU8+fPx4UXXoiffvoJANC4cWO88MIL2LRpE7Kzs7F37158+umnaN++PQBg9+7d6NWrF956662zIsC7ceNGdOjQAZ988gkAoEaNGnjsscewevXqiHbt06cPENaPY8aMiRo4P9tNmzYNx44dM+6OKjc3F2vWrIHJZELXrl2NyRUSunPAYrHgtttuK/UZCt9uvfVW2O2BWwxP5rvvvsOLL75o3B3VmjVrcOONN+r1bt++PRYsWIDs7GysWrUK/fr1A4LXtf79+0cNihYUFODuu+/GunXrAEMZy5cv1z+X27dvxw033IADBw4YSgiItT1kWcbo0aPx1VdfAcHzWL58ObKzs7FgwQL9c/rDDz9g6NChFRrDiqLgP//5D26++eaIL3VORWX6paCgAHfccYd+vU5ISMCUKVOwd+9e7NixA//3f/8Hs9mM7du3o0uXLqWurampqRgyZEipdou2de7cWT+uc+fOaNKkif5/43lU1riVW6K+JiIiOpMcOVKACS/MxWVdnkO3DuNw/TWv4evZa+D3n5jQkHusCDf1m4Tnx32LxT9tRp/LXkK3DuPwyIPTcfx44E4cVVXx66oduH/wB+jWYRy6dRiHAf3fwKKFkZNDQrxeGV/PXo2BN7yp57+sy3N47JHPsfevyDWrEfyieP26v3DXwLfRrcM4XNppPN58/Qe43Scm7oT8uT0bvbq/gBnTV2LTH/sw7O739WMee2RG1PIB4OiRQkx4YS56XfriSc8HYXUY0P8NvQ53DXwbSxZtiVrnWM2YvhLdOozDqhV/6vtUVcWSRVv0dunWYRx6XfoiXnx2Dg4eOA4E1/wdcf8neHzUFwCAd6csQrcO4zBj+kq9HK9Xxg/f/4HbB0zRyxl8x7tYv+6vUn/bzZi+Er26v4BtWw5i0cJNev1D7bt/X25EflRyfIT6b9pHy/HFp6twWZfncGmn8Xjx2TnwROnvqhIXZ4PdYYHP54fHE/m7cu6xIkx5YyGu6vmyPjbGjf0KR49Ev0u3pMSLdyb/pOe/qufL+OLTVdj0xz706v4CXhj3rZ73hXHfolf3F/Dn9uyIMiozvl4Y9y1u6jcJ2Qfz8MWnq9Cv9yv6eU54YS6O55Z/xxzRueisCe6eTElJiR4wyMjIwGuvvYamTZsiPT0djRs3hslkMh5ySho0aICJEyfqwcwRI0Zgw4YNxmwAgKysLNxxxx3Iz8+HxWLBhx9+iMzMTGO2arF06VJs374dJpMJI0aMgMPhgM/nwxdfBH7IVcaIESMwderU
MrfVq1cjKysL3bp1AwDk5+fjueeei5iZN23aNGiahpKSElx0UdkL2jscDqSnp5e51a1bF9deey0WLlwYccv0u+++i/3790eUVVFff/01+vXrB5fLBbPZjA8//BB//vknnnzySbRu3Rrp6enIzMzEHXfcgdWrV2P27Nkwm81QVRUPPfRQucuEnAnWrFmDyy67TG+f8ePHIysrCxMmTECHDh0i2nXBggX45ZdfkJaWBgQD5x9++KGhxLJddNFFKCkpgaZpmDZtmjH5rLRx40Zs27YNzZo1Q4sWLYzJJ+XxePQvVdq2bYspU6aU+gyFby+//HKFvozKysrCo48+WuqXnWhyc3Px4IMP6l9+PfPMM1i1ahX69OmD9PR0dOnSBXPmzMHHH38MURSxePFifPBB4GEd4T799FM9+Dd06NCIMrp164b58+dj6NChQDDAG62MqmiP7777Ti87dB7dunVDeno6+vTpg2XLluHee+8FAHz11Vf47rvvIo43OnDgAAYMGICHH364Qu1Znsr0i3GGb7du3fDnn39i+PDhyMzMRNOmTTF+/Hj88ccfaNSoEfLz8zF27NiIJR4aNGiA//znP6Xazbg9+eSTOHz4MBD8ovHDDz9EYmKiXs4XX3xxyoFdAEiynvjSMvw1ERHRmWL71oO454738POSrejeowUu6d4cx3OL8car3+PtN38qFdRbtmQrnn/mWzRtno7Lr2iFOhnJSEx0wO9X8M7kRXjskc+xbetBXNK9OfpcfT48Hhnjn/4ar748PyJYXFLswZh/f4E3Xg3c3XhNvwtxTb8LUScjGb+u2oEHhnyInX+eWIJK0zR88dkveOhf05C19ygu6d4cPS4/D3O//R2PPjQDBfkuPW+4Pzbsw2P//hw5Ofnoc/X5aN22Hn5dtRP33P4uli2NvPtn3Zo9uOvWtzHvv+tRr14Krr3uomD+HRh069tYtiRyScPjucUYNeIzvPHqAng8MvpcfT4uv6IVsg/m4ZknvyxV5+qgaRpmzvgVzzz5JfLyStDn6vNxTb8LUbNmAhbM24j7Bn+Av/YcgWQS0aVrU7Tv2BgA0LxFHVx73UVo3KQWEOyP58d9ixfGfYvDhwpw+RWtcPkVrbAvKxcP/Wsavvjsl1JjQdMCQeJn/+8b1E5PwjX9LkR6nWT8umonHhjyIf7ac0TPW9nxETJrxq94963F6NSlCbpc0gz16qXCZrfoQe7wACnCgt/G/RW1a+dhHDlciNrpSahd+8Tv2jt3HMKwez7ArM9/RWKiA9f0uxCt29bD4h83465b3y61Ru+hnHw8cO+H+PzTVXr+9PQkvD35J0x44TvIcum6Gp3K+Cop8eLFZ+fg3bcWoVWbuuhz9fmIi7dh3n/X4+Hh05GXV/adiETnonMmuKtpmn57dcOGDSP+aK1q/fv3x7Bhw4BgIHP06NGlZrnJsoxx48Zh586dAIDnnnsOl1xy8qeCVgWPx4NPP/0UANCsWTPcdNNNeuD1p59+0v/Ar0p169bF22+/rQdhfvnlF+zdu9eYrcqElskI3UK8a9cubN682ZjtpA4cOICxY8dCVVWIoogvvvgCgwcPhiRJxqxA8H0HDBigzxRWVRVPPfUUcnNLf2N7JiguLsbYsWP1oN6ECRPw9NNPl5oJGa5z5874/PPP9S8wnn32WezatcuY7R9j6dKl0DQNPXv21IPelVFQUKDfddC6detSgcpTEZq5unPnTjRu3BgOh8OYJcKqVauwdu1aAEDfvn0xZsyYUncQCIKA22+/HTfeeCMAYNKkSfjrr7/0dL/fj59//hkAkJSUhJEjR5Yqw2w244knnkBGRgYAYPHixSgqKorIE2t7uN1uTJ06FZqmIT09HY899lip87Db7XjmmWfQpEkTaJqGqVOnRl0GorCwEM8++ywaNWqEb789tV+Mw1W2X7Kzs/VrdVpaGt555x3Url3bmA0tW7bU1w3/6aefMG/ePGOWchUUFGDIkCHYs2dP1C8a
i4uL9TXZQ31XWVOubIc2NRLRpkYiplzZzphMRET0t/N4ZHS/tAW+nfdvjH9xAF569Va888G9sNnM+Hnx1lK3pns8MgYP64E33xmE8S8OwOgnroUkiVi1Ygdmff4r6mQkY/rMf+GlV2/FU+Oux8yvR+Ci9o3w3ZzfMW/uick/P3z/P/y+dg+uve4i/H979x3W1PXGAfx7gbBH2FuQpag4UFHBWRfuietX0VaxWotb696zWltntdY6a92KUvdEwI0LFQWRIUNA9g6Q3x9JrslNgkyr9v08z32M95ysc08uyXvPec/+wxMxe34/zJ7fD/sPT8S47zsjN6cQV6WCqS9fJGHXjmswtzDA7gMTsGrdcCxZ6YPj/0yDubmB3ChLiZshL+HZ1gVHTk7GvMUDsHn7N1i8YjDKysqwc/tVNtiVmpKNX9edQX5eERavGIyd+77DrLl9sGnbaPyy2RfaOhpYuzqQDViWlpZh145rePQwFr37uuPoqSmYt3gAlqz0weGTk9GkqZ3ce64NSYmZOHTgJlwbWGP/4YmYt3gA246jvm2PslIhnoa/gbq6GoaP9MKgIR4AgE5dGmDW3D5o7ekMADh5/B6uXX6GJk3tcPjkZCxZ6YMlK32w//BE2NgaY/uWS7h3RzalX1GRAC9fJmHrjm+xadto9nl79G6KrMx8hAa/T0lV2f4hkZtbiIXLBmLVuuFYtW44Rn4j+v1e0wryi3Hl0lMsW3QcQqEQvfo2gwFf9J21IL8YG9efQ8rbLHzj15Htr5u2jcby1UOQn1eEVcsC2L5UWlqGPTuD8Do6RaZ/79z3HRavGIz4uLQP5j6uav/KzSnEu7Rc/HXEX6aNmzS1w+voFDx9ongGISFfqi8muCtNsrBObVFTU8Py5cvZqb+XL1/G1q1b2St8QnGeXckU+I+VZ1ciKioKoaGhAAAPDw/Y2tqyuYIjIiJw9epVzj1qhpOTEzw9PQEA6enptR4Q1NLSQvv27dn/P378WKa8Iv744w820OTv74+BAwdyqyg0fPhwtGrVChCn3ggJCeFW+SQEBgayqSZ69+4Nf3//CvXDTp06YciQIYB4AbVjx45xq/wn5Obmsse2S5cu3OIKiYuLYxcEa926dYXavzyS88uRI0fA5/Oxbt26Dwadb9++zd728/NTGtzn8XgYNmwYACAhIQFPnjxhywoLC9mUDvr6+jA2NmbLpJmamrLT/ePi4uQWEqtue8TExLCBak9PTzbtDJetrS169OgBiN+/dI5hCX9/fyxatIhN21C/fn3s2bPng0FZRapyXF6/fo2EBNHoh969e6NBgwbcKqyOHTuygdfz58/LjShRRigUYvv27eyo3FmzZsldaMzLy0Pjxo3h6OhY5dQjfZ2s8Pjbbnj8bTf0daIFMgghhHx6dPU04ftte2hqqbP7nFws0Ky5PXJyCpCdLXshWF9fC+071pf5rlJUJMCpE/chFAoxbkJn2NmbsmU6upqYNM0bunqaOBv4EHm5hSgsFODpk3gYGeui38DmUFN7P4CEYRg4OIrWhUl9+366+/Wrz1FYKMDwr71Q1+H9ujH6+lrw+/4r8HiKB6GYmupj7Hed2PfHMAw6ftUAnbo0lAl2hQa/RGxMGnr2boaOXzWQeX8tPBzgM6w1sjLzcfG86Hvgm/h3uHLpKWxsjfCNXwdoSM3QMTLWxXj/LuDxVPFPQBhycuQvpteUnJwC5OcVQU9fC5qa718DwzAYO/4rnLn8I3r3Lf8Cc3p6Ls4GPgSPp4rx/l1gZKzLlllY8jFlRg8IhcA/px/IBSXbdaiPhm427P/V1FTRoZMrACAmWpTKorL9Q5qdvQlatBSlf5QmWVhu3uIBFdovce7MIza9gfTWreNKLJp7BIUFAkyZ0RNDhrdm7/P4URwePYhFg0Y2GDq8tUx/bd/JFf0HtcTr6BSE3RUNAHmbnIXQkJewszfB2PGd2PqSvtezdzP2/spUp3/17tcMNrZG7P91dDXRpq0o
iB8bozi9CCFfqk8+uLtmzRowDIPZs2ez+/r27QtGvIr6+fPnoaOjAz09PVy/fh0AcP36dejp6bF1lK20Xh3GxsbYvHkzO+ps8eLFCA4W5fEJDg7GggULgI+cZ1fi6NGjyMzMBMMwGDp0KBiGQefOnWFkJDrx1dbCahoaGkpH4Y0aNQpMOSu01yTJ6vEtWrRAfr5o2tLevXvBiFeJl0zRTktLw5EjRwBxOoiRI0dWONBkYGAAX19f2NnZYdKkSTKj7aRXrz99+jQKCgqwadMm2Nvbs6/B3Nwc06dPV9o3k5KSYGdnB4ZhMGrUKG6xDEnbcvu69AhufCCox6WmpoZRo0bB1taWXSCrIourSb/38l53dnY2fv75ZzRu3Ji9GKOqqopWrVrhwIEDSvsnt21LS0vxzz//oG3btuzjqKurw9vbG0FBQXIBMElb7d27FwCQn5+PFi1agGEYdOzYUS4QGR8fjydPnsDCwgKNGjWSKasoyeJhampqaNiwIbe40qTzeK9atYq9yFAeSe5bbW3tD47MtLOzY4Ob0kFhNTU16OiIVpwuLi5WeowKCwuRmir6MmVmZia3AF112yMqKgrp6aKcal27di035U7Hjh0B8QyLR48ecYtZPB4Pa9asQVhYWJVeE6p4XDIyMtgZJ/Xry/545OLz+WzQPDw8nB2N/yEPHz5kX1fDhg0VXuAxNzfHxo0bERUV9cE88oQQQsjnSl9fC3xD2Qu46upqMDDQhkBQiqJC2UV2jYx1weeLvvtIvEvLRVRkMkxN9dG4aR2ZMgCwsDSAg6MZYqJTkZSUCU1NHhYuG4SAszPg7GKBjIw8PLgfgxNH72LOzINYPF92AEVhQTGehSdAQ4OHBo3kv7PZ2hrD2uZ9MEtaU3c7WFjK/hZTVVWBZ1sXAMCzp6LvgxHPRLlOW3s5Q1VVPhzQxssZGho8PAtPQGFBMZISM5GdXQDXBtYwNdPnVoeDoxlcG1ojLvYd3ibLr9lQUyws+bC2McKdW1HwG7UDAcfvITkpU+77fnnexKXjTXw6XBtas4F1aQ6OZjAx0cPzpwnIypJNf+HoZC73HcrEVE8m0FzZ/iHN0ooPLe33Fx6qy8KSjz79m7Obs4vo96q+vhaWrx6Cc1dnY9CQ92vnAMCjh7EQCoVo264edHRlv8MzDAP3FnUBcVoPAIiLTUP6u1y4NrCGkdH7QDnEfa+1lyjQWp7q9C9Jqg1p9nXfB9QJ+S9hP8lPnz6FqakpG3xStjk7O1c5t+mXxsPDg11Vvbi4GNOnT8ejR48wbtw4FBcXK5z+WtukA5aurq5o0aIFIA4yd+vWDajGwmofkpmZiZcvRVNStLS0FE4vrkkFBQUICgpi/9+4cWOZ8g+JiIhgX2/Lli1Rr149bpVyff/994iJicGGDRvg4SGa9sOVmZmJgQMHYtKkSYiNjWX3p6SkYP369XBycsI///wjc5+akpiYiLCwMEAcsCsv17Ei3t7eiIuLw+7du9G1a1elqSoq68qVK3B0dMSMGTPw5Mn7RQXKyspw584d/O9//4O7uzueP5fNDcaVmZmJYcOGoXfv3ggJCWEfRyAQ4Pz58+jQoQO+//77Ci2mpUxQUBDS09Ph7u4OK6uqjUaUBEgNDQ0RFhYGb29vGBkZsedUc3NzTJw4USYFgjLSuXNHjx7N5pWtLTExMewXZk1NTXYkbHJyMg4dOqTwy/S5c+fY/OdeXl5yF3yq2x6SxwaAOnXkvzRLs7W1ZS9oSH/+JPT09DBr1iykpaVh1qxZFb74wfWxj0tSUhLy8j6cR0wgEGDVqlXsxb6lS5fCzEz+hwwhhBDyX2BmbgAtTcXBs9LSMrlRgXxDHairy15EloweLSgsxo7fruCnladlts2/XkDK22zk5xexgaiSklIc/CsU3TqsQt/uazFpwm6s/+kf3AqJBF88HV5CKBS9Fh1dDRgYyM8kUldXA99QNuAsUdfBTC74CAA6OhqAeHRwcXEJ3r3LgYYG
D2bm8oE0ADAw0IaOrgbycgshKCll34eyx9fQ4MHQSBdFRQLk59XeAmAGBtqYvaAfTEz18Do6BetWB8Kn36/o2WUN1q0ORFzshxdLzskpQGlpGZKTsrDpl/Nyx2/HtisoFpQg/V2u3OJh0iNElalK/5AwMNCW62/V0bSZHWbN7cNuO/d9h7kL+yM3txCrV5xC5Itk7l3YEeR3br+Se+0/rTyN82cegWEYvHmTjoL8YsTGiNrc3kFxQNXMXF9mJK4iVe1fmpo8GBop/iwQ8l/EBncbNGiA8ePHy5YqMH36dNja2nJ315qJEyciMTGRDaJCvDhXYmIi7t69iw4dOiA6OhqRkZFo3Vo0paB169aIjIxk69Tmj9lJkybBx8cHEI/c+uqrr9hp/h8zz67EvXv32KCYj48POy1YTU0Nw4cPB8SB6L/++kthYKaqhEIhAgICcO/ePUCc97huXdGVvdpy5swZNuVA/fr14e4umobTuHFjREdH49y5c2ywZvDgwUhMTERiYiK6du0KiC9oSEbNOTk5QVdX9mpjTZgyZQrOnTsHCwsL7Nq1ix0JOnbsWEA8crRv377s+6hJ0dHRbH5lJycnuSDbv+H06dPw9vZmp/ePHDkSISEhSExMREhICEaOHAmIA+/9+vVTGJCTmDJlCo4ePYqmTZvi8OHDiI+Px8uXLzF16lR2pPz27dtlFtPasGEDEhMTMXjwYEB8EeLcuXNITEzEsWPH2JGpEOeYlRyXvn37yo1ArQjphR5TU1MxefJknD9/HhkZGWydlJQUbN26FfXq1ZNJ78IlEAgwZ84c3L17F87Ozli8eHGFZwRIcpCXN+JWIiUlhR3xnpCQIBNEHDlyJDp37gwAmDt3LubMmcP2saysLKxevRrffvstIL6gxE1HUxPtIcljrq2t/cFzu4qKCvv8L168X2lZYvPmzVizZg309RX/uKmI6hwXLS0t9vVlZytefVgiLy+PvbCakZFRodzp165dY1Oq9OrViw3O15bMwmJkFtbejzpCCCHkU5GbU4iz/zzE6ZP35Tbp3L1CoRBbN17Elg0XwDfUxrgJnfH7Lj8EXpyFKyHzMWVGzf1tVlGVD4xVh7qGmsKRvcqoqamCp14zg0GUqe9qhSMBU7B1x7fo0aspdPU0kZtTiIDj9zBy6Ba5heCUSXmbhcCAMLljdzbwodIF6yqjov3jY2IYBt69mmDk6HbIzSnE0oXHlb6Wh2Excq/79Mn7uHE9Qua7OXeRs9r0MfoXIZ8z9mzNMAzGjRvHTvtUpGXLlmwg82PR1dWFpaWlTODN0NAQlpaW7JRfc3NzWFhYQENDdFVSQ0MDFhYWbJ2aGnGoCI/Hw9q1a+HsLJpyIJku/LHz7EIcjNqzZw+EQiHU1dXZxcYkWrVqhfr16wMATp48WSMLnhUWFiI8PBzjxo3DN998w+739/f/YL7JqpA8n5+fHwYPFi0QoKKiguXLl8PcXDQtg8fjwdzcHCYmJmz7a2trw9LSEpaWlmyQTnoRNMnxq2np6elo164dHjx4gNGjR8PGxgaNGjXC77//joMHD0JFRQVlZWVYsGCB3KJ81ZWdnc3+8bW2tpYJXP4bUlJSMGfOHAgEAqioqODgwYPYs2cPPD09YWlpCU9PT+zZs4dtl8jISMycOVPpyNv09HSMGTMGoaGh8PHxgY2NDZydnbF+/Xrs27cPDMNAKBTiwIEDbBCfz+fD0tKSTTvAMAxMTExgaWkJY2Njmc/r27dvce/ePWhra7Mj4CsrOzubDUZCnKt22bJlePLkiVxAWyAQYOLEiVi7dq1cQBMATp06hZ07d0JFRQXr1q2r1IwASS7skpISHD16VOHjQ1xeXn5lAwMDHD9+HP7+/lBRUcGaNWtgYWEBhmHA5/PZ4+vj44OgoCC511gT7aGsPyhibm4OQ0ND7u4aVZ3j4ujoyI4IDwgIQErK+5WWuUJCQpSOZlakoKAAP//8M3uOnDRpUpVHJlfEujsvYbgh
AIYbArDuzvtFRQghhJAviZ6eFrR1NNC4aR2cvzYHN+4sVrp5tauHpMRMXL4YDr6hDn7d4ouR37SDa0NrGBhoK/ydyDCi6ex5uUVyaQEAQAghO1uNS5L3lSstVbS4ram5PtTV1WBsrIeiIgFSpPL8SktLzUFWZj54PDWoMAzMLUSDBF5Hpyj8DpmfX4S3SZlQVVVRmg+4JqmpqcKtSR3MXdQfZy/PxrHTU9Gnf3OUlQlxYH+oXC5baXp6WlBVVUHX7m4Iur1I7phJtotB81CvfuVn7VW2f3xsDMPAZ3hruDawxpv4d9i68aJMgNZUPJp7+Zqhcq9Xetu0bTS0tNXZFAjK+l7K22wUFZX/3f1T61+EfK5kLsXZ2tpi+vTp0rtYjDjvrbJFdP7L7OzsMHPmTPb/qqqq+O677yo8equmREZG4sKFC4B4QaymTZvKlJubm7MLhiUkJFQ4HYAkx7GiTUtLC25ubvjjjz/Y+n5+flWeliydG1fRxn0+Ho+HXbt2VXghNGnSI/lsbN4nx69J2tra+OWXX+RSVDAMgyFDhmDChAkAgDt37rCLRNUU6SBabb2/yjhx4gSePn0KiEd9DhkyBAznSy3DMBg0aBB7oeCff/5RulCetbU15s2bpzBg1blzZ7i6ihY4SEtLQ2Gh8i95yjx58gRxcXFo0KABHBzkFzeoiJycHBgbG0NHRwcDBw7E69evMX/+fDRq1IgNaO/duxeBgYFswHnt2rUyqQcgzjPr7++PsrIyzJ07F3369JEp/5BOnTqxF3Y2bdqEw4cPy315EgqFOHbsGHbt2iWznys7O7vclACS0bKSgLq0mmqPilJVVa3Vi3vVPS729vbo378/IJ5JMHfuXIUjq6OiojB79my5Y1ae4OBgduR5z549a30WydKQ98dI+jYhhBDyJTE01EEdOxNEv0pBctKHB2ZIpunb1zWFsbGeTJlQKMSdW7ILvmpqqaNx0zooKhLgZsj77/ISyUlZiH6l+GLw82cJSE+XXT+itLQMD8JiwDAM3JuLZlXWbyAKWt4KiZRbNAwA7t2NRmlpGezrmkJTSx2WVnzo62vh+bMEuVQFABAbk4ZXUW9hZq4PU9Oqz4b6kEvnn6BP97X4Y9sVmf1m5gb4ZmwHmJrq411aDgoLlQcTLa35MDXTx8sXScjIUP59tqoq2z/+DQYG2vCb8BVUVBhcu/IMITfeX5Rv0FD0mzHs3usKfe+0sTVi+wa37wmFQoTd+/DAhE+lfxHyuZObZ+Hj44OWLVtyd3+UKZ2fq9jYWKxdu5b9f2lpKZYsWVLjIzE/5PLly+zI4QEDBigMevXu3Rvq6qJcU/v376/R16ilpYWffvoJGzZsqNXAtoqKCtzc3LBu3TqkpaXB19dXLkhYWRWZ4lwVPXr0QJMmTbi7AXEgc9iwYVBTU4NQKMTVq1e5VWpMbb2/ipJOcaCtrY3+/fsrPWZqamro27cvIE5bIVkokatJkyawtLTk7gbE+WFNTUVXkl+/fo2cHNGIhco4c+YMhEIh2rRpU+XRny4uLggLC0Nubi6OHTvGLmrI1bNnT0ybNg0QB6MPHDjAlmVlZWH8+PFISkpC586dMWPGDKVtp4y5uTmWL1/OjhQfNmwYRo0ahdDQUCQlJSE0NBSjRo3CsGHD2JHMily6dAlubm74888/UVZWJpNW4/bt2xg3bhwA4PDhw6hXr57cBaSaaI9PRU0cF4ZhMHPmTHbmwM6dO+Hp6YkjR47gzZs3iIyMxMqVK9G8eXO8evUKbm5u3IdQqKSkhD1GDMNg4sSJCv8e1CQjLamVjaVuE0IIIV8SLW119OnnjtycQmxYdxbZ2bIXZcMfx6NH59X4eshmvIlPZ0dyRjxLwOvo90FZoVCIK5ee4uQx+cEdnTo3gAFfG0cO3kL44/dr3eTlFmLj+nPIzVE8aCE2Jg0nj91jR2IKhUJcu/IMVy89RZNmdmjYSBS482zrAjt7E5wJfIBrV57JBPHu3YnG
kYO3oKuniV59RAOFbGyN8VWXhngTn45dO67LjMRMf5eLbZsuQSAoRa++zWDAySFckxyczFFSUoqL55/gTfz7GZgA8PhRHNLScuDgaAY9fdnvPKkp738HmJjooXuPxoiNScOOrVdk3otQKMTli+Ho2GYpJk3YLXdsK6Ky/ePf0qy5Pbr3bAKhUIid26+ygW63JrZwbWCNk8fu4sqlpzJ9o6hIgDXLT6FD6yXYteMaIO4b7Tu6IjYmDccP35EZBRx09bnC/s31qfQvQj53csFdY2NjzJ49W+ZHqra2NubPn1/rPw4/RwUFBfjhhx8QGRkJhmHYqe83btzA7NmzFY5eqw1ZWVnYv38/IB7RKFk8jatp06bo1KkTIB4tevPmTW4VOV27doWfn5/CbebMmfj777/x7Nkz5OTkYObMmdXqJ9K5cSXbkydP4O/vzwaM69ati19//RXTpk2rVq7MBg0asLcrM9W7Mlq2bAk1NeWJ8e3t7dl0Ei9evKjR/vIx3l9FZWZmsrmg+Xw+1NTUkJSUpHTT0NBg+9GjR484jyZiYmJSpTy4FZGRkYGbN2+CYRg20FybGIbBiBEj2GBnSEgIcnNzIRQKsX37dly+fBl8Ph8rV65k8+dW1sCBA7Fr1y72c7Rv3z54eXnBysoKXl5e2LdvH+rXr4/Tp0/DxUW0qrK0+Ph4TJgwAZmZmeDz+bh48SL27t3LptXw8PDA9u3bcePGDfD5fOTn58PPzw9RUVHch/ogZe1RWQUFBSgurvkcsDV5XOzs7BAQEMCOrH748CGGDBkCW1tbuLi4YN68eSgoKMCuXbswYsQI7t0ViouLw7Vroi/dHh4eaNOmDbdKjdvcxR1NzfhoasbH5i6i/OeEEELIl+irrg3Rf2ALhN1/jX7e6zBn5kH8tPI0/Eb9jgljdyI3pxC9+7nD2sYQ5hYGaN+hPgoLBZgwdifmzPgbq5cHYPigTVg87yi6eTeGrp4mkpOzUJAv+s5S18EMM2f3Rn5eESaM3Qn/8buxYvEJDBu0CQ/DYsBTsuiWpiYPe3Zex9dDtmD18gCMGbkdi+cdhZ6+FiZO7gYtbdEAH1MzfUyZ0RPaOhpYPO8oxozcjp9Wnob/+N2Y+sNe5OcVYdJUbzi5iGYeqqqq4Bu/jmjS1A6Bp8IwuO+vWLH4BBbNPYIh/Tfg0cNYdOvRGP0HViyN2e6dQRjU5xel27070dy7AADqOphi5Oh2SEzIwP98NmPOjL/Z17143lHo6Grgm7Ed2UXJTEz1oKnJwz+nwrB88QkEB70AwzAY/rUnWrVxQuCpMPTv+TMWzT2C1csD8PWQLVg876hoJuGQVtDnBIkrqjL9oyL+2huMdh6LsWLxiQrtrwg1NVX4ftMexiaixemOHboNoVAIAwNtTPuxF/T0tbB43lG2Ly2aewT9e/6MwFNhsLM3RY/eosC/qqoKRo1pDxtbY+z5M4it7z9+N+bPPgxdXdFvNR3xv4rUdP8i5L9KLrgL8WjDXr16sf/38/ODh4eHTB0i+oG/adMmBAYGAuLA5KVLl9iFq37//XcEBARw7lU77t69izt37gDilAsODg5gFKQ10NbWxvnz5wHx69+zZ88HA4r+/v74/fffFW4//fQThg0bBldX1xqZ/iydG1eyNWrUCBs3bsS1a9fA5/Px6tUrdO7cGevXr6/QdBFlpFMVSKcwqEnSAVZFpKeNVzV9gDKGhoZsYJm7MNbHJhAI2OBcYmIimjZtCisrK6Wbt7c3Oz3933jt0dHRePbsGerUqcOmd5C4f/8+dHR05D5bkk1HRwf379+XuU9FmJqasuk7JKONg4ODsWDBAgDAqlWrqnUeZhgGvr6+iI6Oxvfffy+zGFnDhg2xZ88ehIWFwcHBAamporxZ0rmaz5w5wwZqZ8+ejS5durD3l+bp6YlNmzYBAJKSktiLTpWlqD0gTsVSURkZGR9cqKwqavK4AICrqyvCwsKwZcsWNGzYkN1vZmaGadOmITY2
Fr6+vkhMTATEn23JRSFFLl68iORk0QrI3t7eVQ48V0ZvJ0s8+KYrHnzTFb2dFI/8JoQQQr4EamqqmDqrJ5as9IFtHWMEX4/A6ZP3ERX5Fm28XLD9z7EYOqINGIaBqqoKfpjaHd/7d4W2jgaCg17gbOBD1LEzxvY/x+KHqd1hY2OEuNg0mTQBHb5qgO27/NCkqR0ehsXg3JlHsLY2xOqfh7MjcLk6fNUAv2z2hbq6Gv459QCvo1PRu587du0fj/qusvljW3g4YO/f36N3P3fEx6fj9Mn7CH8cjzZeLtjz9/fo0bspGKnBXkbGuvh509eYMqMHNDV5OHfmEa5cegora0MsWemDeYv6Q1NLFDz+kMKCYqS8zVK6KcvRyjAMhv2vDZas9IGdvSmCg17g9Mn7iHieiB69m+KPPePg2tCare/oZI4x33WCUAicP/MIVy6GQygUQkdXEyt+GoopM3pAV1cTVy49xT+nHiAlJRs9ejfFnr8noEMn2e//lVGZ/vFvsrE1wrD/iQYAHDtyB1EvRd8d67taYdf+8ejdzx0pKdn459QDXLn0FLq6mpjg3xW/7RwDC8v3C3VbWPKxbecYDPTxwNvkLPxz6gHexL/DlBk9MWuuKGWZqZlsShKumuxfhPxXMUIl0bH79++jffv2MDAwQFBQULkLrX0Ma9aswezZswHxAjbc3Ia5ubno3bs3rl+/jg4dOiAwMFBmEbbacOPGDXTp0gXFxcWwtLRk22nbtm1sLlUHBwdcvnwZ9vb23LsDUu2cn58PX19f7Nmzh1vlg4RCIfz9/bFlyxZu0Qfx+XzcuHEDjRo1ktk/atQo7N27F1DS3pUleTxtbW0EBQWhefPmbFll2uDYsWMYMmQIu0jQ4cOHMWjQIG41oAKPK13etm1bnDlzBnp65f/hkRYdHY3+/fvD3d0dQ4YMQdeuXcHj8WQe90Ntl5SUhNatWyMuLk6m30rvV/TapUnatk6dOrh16xY7pT4hIQFeXl6IjY2FtbU1bt26VancuxkZGejfvz9MTU0xcOBA9OvXDzo6OuW+NmVtLn2fypJuF2WPzyV9PuC2Cz7QHwFg8+bN8Pf3x6BBg3Dw4EGZ0dfSr0ERZY/5IYpe8+zZs9nPYWVV9Tz4/PlztG3bFunp6fD398fGjRsBqTbT0tLC9evXFabvkXjz5g1at26NhIQEdO7cGQEBAZVe0E9Re1haWsr8LThz5ky56YKCg4PRqVMnlJSUYMWKFZg7dy63ihzp47t69Wr8+OOP3Coy58fKqupxKSwsxKBBg3DmzBk0aNAA169fV7hopXS9qvZFQgipjLDwrcgvkF1MR1fbErn5SVBVVUdpqQCAwp8bhJAa8CIiET+M24WOXzXAvMUDuMWE/Gv+2huMbZsvYcHSgejm3ZhbTMhnpY37HKiqVv0iQ2pqKjQ1NSsVc6oMhSN3AcDd3R1+fn6YPHnyvx7Y/RTFxsZizJgxKC4uhoqKCjZt2sS205gxY+Dj4wOIA4Djxo2r0pTiioqJicHJkycBAHXq1MGuXbtw5MiRcjdJaobMzEwcPXqU84ifroEDB8Lf3x8AUFZWhrFjx+LBgwfcahXi5OTE5rAMCwvDy5eVW+H9zp07ePLkCfbs2YMtW7agtPR9jiEJyUg7ZUpLS9n7SY+SrKiSkhKlo1rNzc3Z6dgJCQkICwvjVilXREQEbt26hWPHjmHFihUKF3qqKBUVFTZA2qFDB+Tk5EAoFFZou3btWqUDYdVRWFiIs2fPAuKUJOWl1aiI3NxcvH379oOpMXJycvD6tWjRATMzs1pLOfEhkZGRbO7uVq1acYthZGSkNCevhJ6eHmxtbQFxH5W+hljd9mjc+P0Xww9dLIiNjWVnJkjSHnyu3r17xy5I6Orqys4Q4UpMTGQ/625ubvT3mxBCCCGE1Ip3aTkYOmADpk/aj6ws2YEvb+LTERjwAAZ8bTiLU3wQQmqP0uAuwzBYtmwZJk2axC36zxMIBJg5cyY7
lX/cuHHo168fW87j8bB27Vp2kZyLFy9i69at1UohUJ7Q0FAkJCQAAPr06YNRo0Zh8ODB5W6zZs1ip4IcOnToX19wq6IYhsGCBQvYUYOZmZlYuHBhlQKPBgYG+PrrrwHxwl379u2r8DHKzc3Ftm3b2P+PHDlSYTBOEoxRJiIigg0AN27cWOH0nLy8PKWpMwoKCpQeOzU1NYwaNYp9zB07dlS4nUpKSrBz5042X6mPj4/CUYIVZWBgwAbXXrx4gZQUxav8fgokwTEjIyO0b9+eW4zmzZsjLy9PLggt2fLy8tiRkmPHjoWenh4sLS1x+fJl7kPJePXqFdsXGjVqBD6fjw0bNsjloOZuDx8+hJWVaKpd69atERkZicTERBw7doy9WBAaGgpnZ2eoq6vjr7/+knleaUKhECdOiPJ2GRkZwd1dPndqeno6kpKSuLtl5OTkID5etACImpoa2wer2x4Q56mW5OK9deuW0s+sUChkc4rr6enB0dGRW6XKauq4pKenw9vbG0ZGRhg8eLDSzznEI4pjY2OBD1x0iI6OZs8JnTt3rlBKhtu3b6N+/fowMjLCmjVruMUV9q6gGO8Kaj7HMSGEEEII+fTwDXXg0coRd25FYdjAjVg09wh+Wnkakyfswf98NiExIR3fjO0I+7qiha4JIbVHaXAX4h/E1Vkc60skFArxyy+/4MiRI4B4wazly5fL/dC2s7PDqlWroKIiauIFCxYgODhYpk5NKCgowL59+wBxEGXYsGEKA4Rcbdq0YfNERkRE4OrVq9wqnyxjY2P8/PPPUFcXDYkPDAxUOj3/Q4YMGcLmuNy0aROOHz/OrSJHKBRi69atuH79OiDuA127duVWAwCcPXuWDXJxlZSU4ODBgxAKhVBXV5cJJEqPdH379q3SoGxsbCzCw8O5u1kdOnRg82cHBgZi06ZNSoNh0gICArBr1y4AgKWlJRsErypNTU12+nxycjLOnDnDrSJj//79UFNTg42NDdatW8ctrlXh4eFITk6Gm5sbO/q0qiQj5IVCIY4ePao0eCcQCLBz5062vHv37mAYBnw+Xy4HNXczMzNj+4qGhgYsLCxgaWkJY2Nj9lxgYWGB3NxcCAQCBAYGKn0dz549Y3OId+vWjb1ABXEuXYjPOadOnSq3H127do294NSqVSs2mFnd9oB4QcXWrVsDAM6fP49Xr17J3FfizZs3bH5xT09PhYvEVVVNHRd9fX3o6uoiIyMDt27dUnquKCgowI4dO4APLJgJ8YwCybGpaDoGfX19pKamIiMjA8eOHeMWV8iaWxEw2RgAk40BWHMrgltMCCGEEEK+MJKc0vMWD4CpqT6uXHqK0yfv4/GjOLRq44zf/hiDgT4tKxSfIIRUT7nB3U+JsimoH5v0Qjp8Ph+bN2+GsbExtxoAoF+/fhg3bhwAoLi4GOPGjcObN2+41arl4cOHbGC2efPmbJqBDzEwMMCAAe9zMu3evVtpAPFT1LZtWzavMQAsXbqUXexJQltbW+FoWmlmZmZYsmQJVFRUUFZWhuHDh+PPP/9UmGIB4jQKGzduxJw5cwBxEHbFihVK+0BUVBSWLFki17ZCoRDHjh1jA6jdunVD06aiVUfBGekaGhqKCxcusGUS6enpmDRpEjIzM7lFLC0tLSxYsID9/Pz4449YtmyZ3OuREAqFOHXqFHx9fVFWVgYAWLhwYY1M7e7VqxesrUWLHCxYsACXLl3iVgHEbbZ06VKUlpYiNTUVXl5e3CrVpmw0o/TIVS8vr2qng2jfvj3bdrt27cKxY8fkgqKSC0a7d+8GxKMye/fuLVOnuurUqYOOHTsCAI4fP46goCBuFSQnJ2PChAlIS0uDuro6vv/+e5mLVj179mTfy/r165UG6B8+fMjmtuXz+Rg8eDBbVhPtoampiTFjxoBhGCQlJWHJkiVyqUkKCgqwZMkS9pwwevToT/JCpZqaGts+CQkJ+P333+XOPQKBAGvWrGGD7t98843SHO4lJSVs
SobKjFZ2cnJC586dAfHinFWx/OZzhbcJIYQQUrvq1bfCxaB5lG+X/Cs0NHjw7tkEew9+jxt3FuPGncW4GroAP/0yAg0a2VBgl5CP5LMJ7kqmt0I8qi8yMhIpKSlyP4Rrk3SeXQCYP39+uaukq6mpYfny5WwKgYiICCxdulRprsmgoCCMGzeuQtvOnTsB8WhMyesZMGCA0qCVIr169WKDflevXsXDhw+5VT5ZDMNgxowZbAA0KSkJc+fOlWlbPp/PBueCg4Nx7do1JCUlobCwkK0DcR7fDRs2QEVFBQKBAGPGjIGDgwNWrlyJ8PBwJCUlITIyElu3bkW9evUwZcoUdkG3nTt3okuXLjKPx7Vz5054enoiICAASUlJCA0NxahRozBs2DCUlZWBz+dj6dKlMsEnTU1NjBw5EhDnFh4+fDgWLVqEyMhIxMbGYvv27WjWrBmuXr0KQ0NDqWeT5+Hhgf3790NbWxsAsGjRIlhYWODHH3/EnTt3kJSUhNjYWOzfvx+tWrVCv3792MXCFi1ahDFjxnAesWrq1q3LtnNmZia6du0KX19fhIaGIikpCeHh4Zg3bx4aNGjApjyZMGECO0qzJkmCzPn5+Th48CBiY2Px7t07pKWl4c6dO2AYhh1lWh22trZYvXo1e/Fg2LBhGDVqFPuez507h+7du7OLdpmYmGDlypXVDipzqampYerUqeDz+SguLoa3t7dMf9qyZQvq1auHGzduAACWLVuGtm3byjyG9HvJz89H79694evri0uXLrHHb9KkSfDw8GBHoM6ZM0fmokVNtUefPn0wduxYQPz3oFOnTjh37hz7GB06dGDPkX5+fkoXXfwU9OjRgw2srl69GgMHDpRpDy8vLyxZsgQQp1mYMWOG0i/J0mlaTE1NYWFRsfxmPB4PS5YsqXB9RUy0NBTeJoQQQgghhBBSuz6b4G6jRo3YgMzhw4fh4uICJyencqek1yRunl0fH58K5SM2NjbGpk2b2MDajh07cOjQIW41QLww2o4dOyq0BQUF4e3bt2waAT6fz06/r6j69evD29sbEI8s/uuvv+RG0X3KbGxssHz5cjb1xdGjR2XSKpiamrJTyaOjo9GpUydYWVlh69atbB2IA8U//PADrly5wo50i4uLw7x58+Dm5gYrKyu4uLhg4sSJ7BRwR0dHXLlyBaNHj1YaaIF4hHGLFi3w8OFD9O/fH1ZWVvDy8mJTaTg6OuLq1ato1qwZ964YNGgQxo8fD4j739KlS+Hi4gJ7e3uMHz8ecXFxmD59OqZNm8a9q5xevXrhwYMH7IWG7Oxs/PTTT2jVqhWsrKxgb2+PkSNHsqP29PX1cfDgQSxatAg8Ho/zaFU3cOBABAQEQF9fHwCwb98+eHl5wcrKCm5ubli5ciUboJ81axbWrl1bbvtWVfv27dm0HuvWrYO9vT08PDxw8+ZNvHz5Eq6urjJByeoYOHAgdu3axbaj9Hvu0aMHLl68CIg/j0FBQWjRogXnEWqGdJCf259++OEHZGdnQ0VFBb/++iumT5+usN0VHb+uXbuyx2/Tpk0QCATg8XjYsmULZs6cKfc4NdEekrzmklGvd+/eRY8ePdjHkPRjb29vrF27tkb7cE0zMDDA7t272fd56tQpmfaQvJeBAwfi8OHD5V7Ay83NZReZs7W1rdQCja6urrhy5UqVF57b2s0dzS0M0dzCEFu7yedqJoQQQgghhBBSOz6b4G7dunVx4cIFtGvXjt2Xk5PDLrhTm7h5dp2dnSsVMPDw8MCiRYvY//v7++PBgwcydari6tWriIgQ5Tb09PSs9NR57oJbJ0+eRExMDLfaJ61v377w9fUFxMdp7ty57HtQU1PD9u3bMXnyZJlj9ezZM/a2tA4dOuD58+c4f/48Ro0ahTp16siUGxoaYtCgQTh16hSeP3+ODh06yJQr4uDggKCgIGzcuFFm9HnDhg2xZ88ePHnyRGkQkcfjsbl9u3fvzr4HHo+HXr164datW5Xq
hy4uLrh58yZu3bqFSZMmwcnJiQ2MA4COjg46d+6Mffv2ITExEUOHDpULzFUXwzDo3bs34uPjsW7dOri5ucm8BisrK3z33XeIiIjAmjVrKvzeKqt169Y4e/Ysm28ZAFJTU3HlyhWUlJTAw8NDaaqNymIYBr6+vkhKSsKaNWtkpsqrqKjAw8MDf/31F8LCwuDq6ipz35rWq1cvREVFYdq0aTAzM2P3m5mZ4fvvv0dUVBQmT54MVVVVmftJfOj4OTo6Yt68eYiNjcX333+vsP/UVHsYGBjg4MGDCAwMhJeXF/s6VFRU4OXlhcDAQAQGBpYbDP1U2NjYICgoCHv27JHpkzweD927d8fly5dx+PBhdiE5ZUpLS9nZLCYmJh9MS8Pl6urKpnWorB4OFrg3qgvujeqCHg5VHwFMCCGEEEIIIaRyGOHnNFSTfLZGjRqFvXv3wsjICMHBweUGbT539+/fR/v27ZGfnw9fX98qL/ZGCCGEEPKpCwvfivyCVJl9utqWyM1PgqqqOkpLBQDo5wYhhBBCPl9t3OdAVVU0+7cqUlNToampCT09PW5RjfhsRu6SL4Ouru4nszgeIYQQQgghhBBCCCGfMwruklpXWFiItLQ0QJwqQXoaNyGEkC9DSn4RUvKLuLsJIYQQQgghhNQiirKRWvPs2TO8efMGhw4dwqVLlwDxIkmfQw5MQgghFbfqVgTMN52C+aZTWHVLlAueEEIIIYQQQkjt++KDu2vWrAHDMNXeRo0axX1oUo7CwkLMnDkTtra2GD16NIqLi6GiooLvvvuu0ov8EEII+bStCH2u8DYhhBBCCCGEkNr1xQd3yb8jOzsbiYmJ7P/r1auHw4cPo0+fPjL1CCGEfP7MtDUU3iaEEEIIIYQQUrsYoVBIy9cSQgghpMrOv07GouCnAIAlbRuie10LbhVCyBcqLHwr8gtSZfbpalsiNz8JqqrqKC0VAKCfG4QQQgj5fLVxnwNVVXXu7gpLTU2FpqYm9PT0uEU1gkbuEkIIIaRaute1wK2RnXFrZGcK7BJCCCGEEELIR0TBXUIIIYQQQkiVlJUJwHAG5qpmi3eUlaIMgAoN3CWEEELIZ06IMu6uTwalZSCEEEIIIYRUSfLVDTBuOQK8vabAO9G+wnovofnCBULtYsQ0Xo8LDvrcu/3nWGhaop/1AOyL3YP8kjxuMSGEEEI+QQwYjHOcgMDwVfBuMB1qKlVLzUBpGQghhBDyyUvMLUBibgF3NyHkS8eIf07ovN+lqin+j9r7ff91WqpaAAD1Kv4oJIQQQsjHxzCM5Ban5NNCwV1CCCGEVMvy0Gew3hII6y2BWB76jFtMCCGEEEIIIaSWUHCXEEIIIdWy6laEwtuEEEIIIYQQQmoXBXcJIYQQUi2WOqLpxtzbhBBCCCGEEEJqFwV3CSGEEFIt27q7w9PaGJ7WxtjW3Z1bTAghhBBCCCGkljBCoVDI3UkIIYQQQgghH5J8bROMWwwD75gpECfaJ2iSAN4jawj1ixHTYD0uOOhz7/ZBceFxWDf8F7h7N8Potb7cYjnK6udl5uHs1vPoNq4r9E1qZ4Xqiqir44BuFt44FP83MoszuMW14vz2CzixNoC7WykjKyP8eHQGslKyFLblv2n3zL0IO/cAM/6eijqN6nCLZVTmfUves4GZAbfosyIoFODSn5fh1skNNq7W3OJ/RVZKFtYMXof0xHRukUIDZvZD9++6cXd/dmIexyL82lP0ntSTW1Sj8jLzsOnbLWjQzhV9pvSWWvTp45AcXxMbY0z8YwI0tDWAKvTF0pJSXNlzDWe3nEV+dgEYhsGk3T/A1as+tyqpAMnfQns3O5njUllCoRDhV58iMyUT7Ya1ZfdX5lxcU24H3MHRlccxdf9kWDlbcotrnQqjAj+H8QgMXw3vBtOgVsWFUVNTU6GpqQk9vdr5LkIjdwkhhBBCCCFfnKKCYuyYtBNh5x5AWFbGLf7iaelpwdDSUGbTNtAGADAqDPjmfJky
IytDqKh+OT8PVdVU5d4/d/tS3vPBJYdx9rfzKCst5RZ9EvSMdOXanrtp6X3+aZ3ePE/AryM3Ii0+jVtUo4RCIf7ZfBZ5WfloN6ztRw/slqeyffHBuYc4tuo4BMUlaN7THZ18O8DMzpRbjXxkT4OeYcu435Cflc8t+ujcuzeDfRM7HFlxFEX5RdxiIvb5/yUjhBBCCCGE/KfVaVQHG5/8IjvSVChEWcl/L6gr0X5EO6y6sVxmk7SPcwsnLLm4UKZsxsFp0DOunRFF/4aWfVrIvX/u9qW859KSigXS/g3qWurw/3OiXNtzt/Yj2nHv+tkpKy1F2Ue4kPTq3ivcOBiMrmM6w9DSkFv8URiYGWBl0DJMOzBFZnRoZftiwosEAMDIFf+D38YxGLLAB8Y2xtxq5CMrK1Xcj0ev9cXGJ798tFG7AMDT5KG7X1dE3XuFhxcfcYuJGAV3CSGEEFJt8Tn5iM/596/uE0IIIYR8qQSFApzfcRHG1kZo2q0pt/izk5GcCQDQ1NPkFhHCsnOzg3NLJ5zfdgE573K4xYSCu4QQQgiprqUhz1Bn6z+os/UfLA15xi0mhJBqeXU/GqsH/oTxThMxweUHbB67FUlRSTJ14sLjMMltKnbP3AsAeHzlCSY3noaXdyKRnpiOHz3nYm77BchKyZK5HylfUlQSNo/digkuP2C800Qs6rYUD849hPSyLZK2Xz/iV7kps0X5RVg/4ldMcpuKuHBxUmZx3sa57RcgLT4NF3ZcwsxWszHeaSIm1p+E/XP/QnZqtszjKFJaUor98w5gvNNE/Oq7EQU5BdwqFVaQU4B1w3/BeKeJOLftvMz7ExQKsGXcbxjvNBEXfr8oU/Y2+i32/LgPE+tPwniniZjsNhWHlx9FYW4hW0daRnIm9s/9C5PdprLv948pf7IBLoj77niniWxflpaVkoW57RewfVnS9rdO3EZxQTFW9l8j19aKnnOL329ynyFIHZeXtyPZz9yPnnPx4uYLbtVac++f+5jg/APmd1qEd2/eyZQ9D47ABJcfsLDLEqQnvM/lKygU4NaJ21jUbSnGO03EeKeJWNFnFV7cfCFzvCTKSstw/0wYlvZYztaf22EB7p6+x45YVNZ3Jc5vv4DxThNxfvsFQNx2K/uvQXFBMW6duC13DMtKy/DkajjWDF6n9DkrIv75Gzy/EYFm3ZqyecSfh0RggvMP2PPjPpn3KxQK8ee03RjvNBGHlh6RehRAUCTA5rFbMb3FLCS/SmbrR96Nwrph69k+Pd5pIpb2WI77Z8JkXqekL0o+9xXpi9Ik/fzWidsAgK3jtsm0Jz5i302KSsIWv98++DlW1m+u7r0GQaFApq6kPc5vv4D4Z/Eyx33d8F/w9nUKhEIhHpx7yPbbifUnyT2vpJ13z9yLt69TsOnbLezjrBm8DpF3oxT2cUWyUrJwdOVxTHOfwT4f9/wj6fdbx20DAJxYGyBzXHbP3KvwuFblWFXm/M/T5KF1/1ZIjEzCi5svucWEgruEEEIIqa7VtyIU3iaEkOp6ev0pNozehOy0HHj6tIFNPWuEX3uKVQN+wrMbz7nVWYaWhmgzqDV0DXWhxlODR9+WaNmnBXiaVVsI5b/owbkHWN57FeKfvUGrfh5w8XDG2+i32P7DDtz4O5hbvdIKcwuwZ9Y+nFh7EnWb1kXrAa2gra+F4MOh+NV3Y7mjs0pLSnFo6REEHwpBfc96+G6LX7VytmrpaWHQ7AFQ11LHhd8v4s3zN2zZtb+C8ORKOOp71kO74e/zmz65Go4V/Vbj5rFbsHK2RNuhXjCzN8OV3Vexsv8ameAjAMQ8isGynssRfDgUfEtDePq0gZWzJe4F3seynssR8yhGpn5F6PB10HaIJ4xtjMEwDJp2bYK2Qzyhw9cBxEE/yXOa2Zuh7VAvOLo74MnVcCztuQIPzj3kPiQykjOweexW5GXmofWAVrB0soCxrQkbZOIGdpTtr6pm3Zui9cBWSItPw/nfL7IBxYzk
TBxadhgAMGBGPxhZGwHiwPyuGXuwe+ZepCemo3lPdzTv6Y7k12/xy8iNuLjjkkzwq7SkFAcWHsSOSTvx9nUKmnRpjOY93ZHzLgc7p+7CgYUHK51aAACcPZzQtGsTMAwDYxtjtB3qBWcPJ0D8nCd+Ooktfr8h5nEMGndujNYDWkFQUFzp57x7+h5KS0rh0tqF3WfpZAFDK0PEPIpFXmYeuz8vMw/xz0R9OSEiQebiS1ZKFuKexsPOrQ6MrY0hFApxcccl/Dz8F8SGx6FptyZoO9QL9Vq7IOlVMnZM2onLu6+y9+f6UF/kMrQ0RNuhXmwKBte29dF2qBes64sWYavJvlueB+ceYmnPFXhyNRw2rtbw9GkDfVN9uc9xUUEx9szahx2TdiI1Pg3Ne7qzx/DQ0iPYMu43hReY7py+h9UD1yI3XfQ3zMrZElF3o7DF7zccXHwIv/v/AQMTffb8d2X3Veyff0Au4P82+i22+P2Gl3ci0bynO9w6NcLrh6/x8/BfcOnPKx8M8MY/f4PVg9bi0p+XocPXgadPGzi6O8idf1RUVdCoUyO4thUtaFenUR2Z46JIVY5VVc7/Du51oW+qj9BjN+WC6YSCu4QQQgipJmupH9TStwkhpLpy0nPRsncLLLmwEL6rvsbcU7MxdsO3EBQJcHTlMYU/AAHA1tUGwxYNgZWzJfRN9TFo9gAMmNkP2vp0jqqoooJidP+uK1YGLcPotb6YdmAKxm74FgzD4OaxWwoDGZWRn12ArNRsLL24CN9v/w6j1/pi6eXFcGrphMTIJLx++Jp7F0A8eu7iH5cQdOBGjQR2Jewb26HnRG/kZxfg1K+BEBQKEP8sHme3nIW2vij4K3mejKQMHF5+FIIiAUatGYm5p2bj6xUjMPfUbAyc1R8pMSn4e8khNgBRkFOAQ0uPoCC3EKPWjMTi8wvY/jxwVn/kZxfg3LYLEBRVLmBhbGOMIQt84NzSCTxNHnpO9GZzlmYkZ+LQ0sMoyC3E2A3fsq9x2oEpmLzHH1q6mvhrwd9IjJQdWScsE6JOozqYGzAbo9f6YsreSTD5iDlQVdVU0du/J0xsTXDjYDBe3HyJstIynN16Dsmv3qLdsLZo0rUxWz/owA2EnXsAp5ZOWHF1Kfw2joHfxjFYcn4hzOzNcGJdACJC3o/efHDuIYIPhcCppRNW3ViOCdu+k6kfciQUr+5Hs/UrysvHEz0neoOnyYNzSyd8vWIEvHw8AQCPLz/BpT+vwMTWBIvOzpfp7/U96yH4UAhCDodyH1JOflY+oh9Ew8DMABaOFux+fRN92LnVQdqbNKTFvx/tnBb/Dmlv0sAwDJJfv5U5Xya+TER2ajbqtXYBT5OHpMgknNt2Hmb2ZlhyfiH8No7B1ytGYOr+yZh1ZAbUtdRxL/Ce0gW2yuuLiti62uDrFSPg3FIUAO/k2xFfrxiBRh0afrS+m/bmHQ4tOwItXU1M3uOPOSd+hO+qr7H4wkK0HeqFlJgUXNlzDRD3s9sBd1CvtQtWB6+A38YxGL3WF8uvL0Orfh6ICH2B0xv+kQuyJkQkoM+UXlhycRF8V32NOSd+RIN2rkiJScHN47cx4+A0TDswBaPX+mL28VkwsTXB8+AIpMSkyDzO60cx0DfVZ/v4xB0TMPPwdGjra+H8tgtIihKNvlakKL8IR5YdRUZSBnpP6onF4r+n0w5MwXeb/VCQW4i9s/cj510OeBo8dPPrgk6+HQEAzXs0Y4+LIlU9VlU5/+sa6cLMzhSxT+LwLkF2VD+h4C4hhBBCqml79+ZoZ2OCdjYm2N69ObeYEEKqzMLRHP2m9wVPkwcAYBgG7t7N0Lynu9IfgKRmWDiao+PIjlBVU2X31fesBwtHC2SlZqO4oFimflV4+bSBqZ0p+38tPS24dRQFERQFK4RC4NLOyzi57hTqtXaB38YxSgO7kqnx5W2Przxh6zMMgw7/aw+nlk54ciUcIUdDcfynk8jPLkD/Gf1g28CWrfvo0mOk
xqaieU93tOrvwY7mZRgGnXw7okE7Vzy/EYF48QjgmMexiHkciyadG8OjX0uZ+m2HesGxuQMyU7KQm/F+1GV1hV8NR/Krt/Ac1Abu3s3Y5wQAV6/6+Gp0J+Rm5OLuqbsy9wOAVn1byrWrZAEt7mJKyvZLSKboc9teeuOm9DC2MUb/GX0BIRCw/hRuHb+NGweDYeVsiT6Te0FFVRTGyE7Lwc3jt6DGU8PAWf2hb6rPPoaRtRGGLvQBhEDI0VCUlZZBUCTArZO3oaKqgj6Te8nV7+3fE1p6WoitgRHIEoJCAW4cDIZQKET/6X1lgrJaeloYMn8wtPW1KnTBJDstG2nx72BmZypzoUpFVQUNvFxFFySexrP745/Go1RQijaDWiM7NRuJLxMBcfqFZzcioKKqAsfmjoA4eAiGgZdPG3ZUtISRpSF0DXWR8y630hcgqqKm+64yz288R+bbTLQd1hb1Peux+1XVVOE9vhtM65ggLT4N7968Q+iRm1DXUseg2QNkRiNraKmjz9Te4Jvzce+f+3KpRIysjdB6QCu2z/I0eagnHnXdsEMD1G1qz9Y1tDSEbUMbCIoEcudXRX3c0d2BbYsnUucyrqh7rxB5Nwp1m9ij8zdfyZzTm3Zvgg4j2lc53UF1jlVlz/+aOpowsTVBXmYe0jjtTCi4SwghhJBq+srODEH/64Sg/3XCV3Zm3GJCCKmyOg3rQM9YV2afiqoK6rUS/Th+/bDyU9lJxfDN+dDQ0ZDZx9PgQc9IFyWCkhoJ8iia6mvpbMndBYgDUpf+vIwT6wKgqauJIfMHK53yDXGAxtDSsNyNpyG6aCAhnZ7h0JIjeB4cgeY93OE1RDQCE+KRwy9ui4IgzXu4s0EbCZ4mD84ezigRlCD6gWgEaFx4HIRCIZxaOsoEVgBA20AbMw9Nx+xjM2FowZcpq46Yx7EAgEYdG8q9RtH+RuBp8vD6UQyKpAJJDMPApE75U9krS89IV67tpTd9U31AKigEqfQMMY9jsW/uX1BVVcWIpcOgZyzKMwsAKTEpSIlJhX0TO1i5yPcbKxcrGJgZIOZRLHIz8pCflY/El0kwsjKCpZN8fY9+LbH+/lp0HdOZW1RlWalZeBORAL45H44tRIFUaUbWRrBysUJSVDLevZFN5cGVkZyJ/Kx88C34UNeSTTFj51YH6lrqeBbyHGWlZRAKhYh++BpGVkZo1r0p1LXU2fNlXmYeXtx6CdsGNrASf968fDyx/v5adP+uG/KzC5D4MhF3Au7ir/l/Y43POqQnlv/aatLH6ruvwkSfT5dWzjJBSQAwsTXBsitLMP63cSjIKURGUgbqNLSFWV3577nG1kZwaeWM7NRsJEe/lSkzsTZWGmy2d7OTeX8qqirQ0JI950oo6+P1PetDRVUFMU9i5VI5SESJ8/I27tJY7rUwDIN6bUR/T5+HVD61WlWPFSp5/pewdBJdHJHkiSbvybc+IYQQQgghhHwCLJ0t5X50A4CBuQEgtdI6qXl8cz40OAEkicLcQqXTsytKXUsd+lKBug8RFApw9/Q9QCh6/pCjN+WmQEtr2acFVt1YXu7m6iXKKynNvrEdun/XDUKhEDp8HfSZ0lsmICsoEiAvPQ8MwyD02E3sn3dAbnspDv6+eZ4ASI1CM7OXDwzVBkGRAFmpWeBp8mBoacgtBgDoGupAS08LBTkFKC0uYffzNHnQMdCWqVsd6lrq8P9zolzbS29jN3wr19dU1VTR64ceMLIyglAoRMeRHeSCo/nZ+SgrLcO7hHQcXXFM7jic+uU0SopLkJ2WjczkDGSlZCE3IxdGlobQ0Fbct2taflY+CnMLUVxQhFPrT8u9xqMrjyMjOROFeYVITyo/gCooEkAoFEJFVUXuvGjuYA57NzvEP32D3Iw85GXmIeZRLGwb2qBOQ1sYWxuzAbbkV2/x9vVbODRzgLbUsX55JxKLuy/FNPcZWNpzBf6cvhs3DoryazMq8ufh2vCx+m5RfhHSE9IrdB7K
Sc9BUX4RjKyMoKEtH3xlGAbG4tHO+Zmy50Uja8X3qSxlz62uyYOauhry0vOUXnCT/J18fuO5XP/bP+8Abp+8A4ZhkBqbKrcoZnmqc6wq0u7lKRVULEf1fwkFdwkhhBBCCCGEfPKsnC0xdf9kmNiaIOivG3h17xW3SrUV5hYi8k4kIB7h+PjyY4VBZKFQiCdXwhF8KERuex5c+RFw/xaeBg8qap9mWCD+6RtkJGcAAB5feYKMRNFtroykDAQfDpU7DjeP3UJuRi63+r8iP7sAN4/fknuNwYdC5KbyV4WGtgYcmzsgPTEdKTEpSE9Ix7uEd3B0d4C+qT5sG9ggMTIJOWnZeHX/FcpKy+Dq9T4VwfPgCPzy9Qakxqah7VAv/LDze6wOWYEtERvx45EZMLRQHLz7N32KfVeSQqimKQroV9bLO5FyfS/4UAgeXnyk8BxXkz7FY/WlodYlhBBCSLW9zszDa6kVmgkhpCa85Uxxlch6mwWIR6uR/waeJg++a0bCpZUz+kzphRJBCY6uOo68GvzbIxQKcePvYESEvoCLhzOMrIxwbtt5xIqnHkMcpNAx0gFPk4fZx2dhW9QWpdvotb6AVD/lLpJUW3gaPBiYGkBQKEBGkuKAaObbLOSm50JNXQ2MyqcXFkhPSMeJdQHgafDQpEtjpMSk4MzWcygteT9iT1tfGyqqKvDo2xK/RW6Wa3/JJskHrGeiD219baQnZaAov/o5oytC20AbmrqacGrhiF8f/Sz32qS3xl+5ce8ug6fBA8MwbNoFLpfWLhCWCfHq/ivEPolDSXEJ7Bvbg2EYOLd0Qk5aDpIikxF5NwpG1kZsjuSy0jIEHbwBYZkQX68YwS6gxTfny6URqW0fq++qa6mDb8FHcUExspUszCmhZ6QHDW0NpCemKxzZWlZahuTXor9VtRXcTU9Q/Nx5WfkQFApE5yROmhkJSbqX77b4yfU56W3agSkKRwcr87GOlSKqvI/bLz8HNde6hBBCCPlPWhz8FA7bz8Bh+xksDn7KLSaEkCpLeJEoF7wrLSnF06BnYBgG9o3tZMrIx2dgZiBabCldfrGlnHc5bNCjuhiGgYp4arh792Zw+6oRYh7HsotV1YTYx7E4t+08DMwMMHLV/9Bveh/kZxfg5M+n2MWuVFRVYO9mB0GhADGPKpbzuU5DWzAMg6i7r2SCkxAHhnb478S8jgsRFx4HvpkB1LXUkZsh356pcWnIfFuxVCSSz0b4tacKc3FGhEagrLQMlk6WcikR/m2lJaU4s/UcUmJS0OXbr+C7+mvYN7ZDyJFQPL70fuEoE1tj8C34iHsah5x3Hx6hq8vXgZWLJdIT05EUlcQtRtS9V5jWfCb+XnQIPA0ejKyNICgUII+TgqSstAwxT94H/MujZ6wHi7rmSHyZiPSE8tMufIihBR/aBtrITM6UW3AL4nykhlaGiLwThYibL8C34MPE1lhU5mwJRkWUSiTuaTzs3OpA30S0OJck1Yi6ljqs61lxHhV4E5GgNHhXGz5G32UYBg7N6gIAXt6OlDuH5GflY82gtVg1YA2gwsDQ0hBxT+OR8lr+Ak16UgZeP4iBDl8HJjai9q5pbyLeKBzh/er+KwiFQjTwclWY8xYAu2jbi5sv5d5ndX2MYyVNkuJGemFCIqL46BNCCCGEVNBPt18ovE0IIdUV/yweNw4Gsz8ahUIh7gTcxaPLj+Hc0on9cV6eovyiD65CT6pOXUsdBqb6SI5ORuTtKHa/oFCAs1vPIzs1W6Z+TeBp8tBnUi+oa6njwu8X8eb5G26VSivIEQVx87ML0OuHHjC1M2WDyBGhL3Dj7/dBZHfvZjAwM8CpX06zizJJ5GXm4ZevN2Bi/UmiHMHiAIidWx08uvwYD84/lAmwPA16hkeXHkPfWA+mdqbQ5utAW18b0WHRSIgQ5eyVPO7pXwMVBlAgbm/pIGSjTo1g4WiO0GM3EXbugcxzPg+JwJXdV6GtrwXPwa3Z/Z+Kx5eeIORI
KKycLdHJtyN0+DroN60vAODEugA2SGpgZoDW/T2Q/OotAn4+BUHh+2C4UCjEvX/u4/t6/lj/v1+Rl5kHniYPrfp7oKy0DGc2n5W5cFRUUIzz2y8gPysfzi2doKKqAhNbYwiFQtw/E8YG5YVCIcLOPcDDC4/Y+3JJB+Y1tDXgNdQT+dkFOLT0iNzFqldh0ZjmPgOLuy9FamyqTBmXvok+TGyNkfz6LXIUjDbVN9GHnVsdPA16hocXHskEcC0czWFe1xwPLzxCdmq2TDBQMhq9uKAYz4Kfy/SV1NhUHFp6pFJBQW5frKyP1Xcbtm8AAzMDBB8MRvSD1+x+oVCIu4H38PpRDEztTGHlZAFPnzYoLijGsdUn5PrN6V8Ckfk2E817utdaXu387AKc335RZlGyV2HRuLL7KkztTNGoY0OZ+tIcmzvCvrEdrh8Iwv0zYTLtKSgUYN+cvzDB5QcEbjwjcz+IR96W52MdKwAozCtEWnwa9E31YUGzduRQcJcQQggh1WKr/37xCunbhBBSXdb1rBCw/jQWdV2CvXP2Y2Xf1djz4z7oGOhg0JyB5U4hVddSh7GtMfIy87Dnx/04/tNJ9kf57pl7Md5pIs5vvyBzH2X7z2+/gPFOE7F75l6Z/QTQ0tNCm0GtISwTYvsPO7B+xK/YPXMv5rafjwfnH8CtUyPuXWqEbUNbdPn2K+RnF+D4TyflpizfPX0Pc9rN/+D2PCQCQql0DG5fNUKbgaJABE+Th75TekNbX0smPYOpnSmGLRqCovxirB3yM1YNWIP98w7gt/HbMav1HLy49RL1PeuxARcdvg6GLRoCLV1N/DH5T6zsuxr75x3A+hG/Yovfb1DX4mHoQh9o6WnB2NoI7j2aIT+7AGt81mHrd9uxY9JOzG47DxnJmajbRDQKT5qlkwWEQiEOLjmMQ0uPIDUuDYYWfAxdqPg5N4zahILcQvjMGwwbVxvuwymUlZKFue0XYJLbVMSFx31wv0RxQTE2fbtFrt252x+T/0RRQTGbjgEAfOYNhp540aV6bVzQblhbmfQMDMOg69guaNi+AUKOhOJHzznYMWkn9s7Zj8Xdl+GPyX+CYRg2QAwALXo1R9uhXnhx6yVmt52HHZN2YvfMvZjfYQGeXA1H26FeaObdFADg3sMduoa6CD4UgsXdlrLnoJ1TdqHxV25yU/D1TfWha6SLp0HP2MXIhEIhWvRqjvYj2uHFrZeY1XoOfhu/HfvnHcCqAWuwdsjPyM8ugNcQL5jUMZF5PC5tA204NHNATloO3ioYQaqiqoJmXZtCKBSitKQU9m52bABXW18bFk7mEAqF0NbXgmNzB5n7tRnQCowKg5PrTsn0lQWdF8PAVB91m9ZFbkYuslLKD/Yp6ouVVdN9VxlTO1MMWzgEBbmFWDvkZ6wf8Sv2zzuAlX1X4+9Fh2Bmb4aBM/tDRVUF7Ue0Q6t+Hgr7ze2AO3Bq6YTe/j2Vjp6tLjV1NdwNvIf5HRZg98y9WD/iV6wd8jOK8osxcFZ/pQuaQXz+Gb5kGHQMdPDH5D+xuPsy7J2zHzsm7cSPnnMQciQUlo4W8Bz0PgArmUEQciQUu2bswePLj2UeU+JjHSsAyEzORFJUMuo0tIWBmWhRVfJe7fQ8QgghhPxn/O7dHJ3qmKFTHTP87t2cW0wIIVXm0bclpu6fDJ4GD6FHbiIxMglth3hi3uk5sHMT5YtUhmEYeI/vjrpN6+L1w9e4uvdalQIN5MPaDW+L7zb7wdzBHC/vROJ2wB3YNbbDj8dmwamFI7d6jZAE7aycLfE8OAK3T96RKS8tKUVGUsYHN0GRgE3HoK2vhb5TessE7WxcbdBtXFe59AzNujfF/MA5cOvUCG+eJyD4UAgeXXoM87pmGL3WF+O3joOWnhb7OPZN7LHgzHy0HeKJxMgkBB8KwauwaLh7N8Ock7NhLw7aMgyDQbMHYNiiIdA30cfjy4/x8MIjtOrbEjP+nqowz3Sbga3RvIc73ka/xdW9
19iRzK5e9dnnTIlJYZ/TrVMjLDwzD20Gta72Ik0VkZOeK9fu3C07NRulghI2HUO7YW1Rr40L+xgqqiroPq4rTGxNZNIzaOlpYfzWcRi60Aeaelq4fyYMoUduIjMpA20GtcbCM/PQrLsoWAsAqmqqGLF0GPw2joGRlRHunwnDrRO3wdNSx6g1IzF8yVA2z6yVsyVmHp6ORh0bIu3NO4QeuYniIgHGbRqLnhN7yLWdgZkBBs7qD12+Dh6ce4jbJ++gKL8IqmqqGLZoCPw2joF5XTM8uvQYwYdC8OZ5Atw6NcKPx2aiy7dfyT2eIi37tICqmiqeXAlXOJq2TiNb6PB1oKKqAsfm7z97KqoqcHQXBXRt6tvAmJM+wO0rN0zZ4w8rZ0vEP3+D4EMhyErLxui1vpi06wfUb+OC4oJiJL8qP82Ksr5YWR+r7zbzboqFZ+ahUceGiLwXheBDIUiJSUHHrztg1pHpMLI2AgBoaKlj1E8j4bdxDExtTWT6zdCFPpi86wfom4pGSdcGh6Z1MfPQNJg7WuDWiduIvBeFRh0bYn7gHJn+rYydWx3MOz0HbYd4IjMpA6FHbuL+mTBo6mlh4Kz+mHVkBvteAcC6vjX6TOkNoVCI2yfv4N4/siN+pX2sYxV5Jwp5mXlo3b+V0vzC/2WMUNkRIoQQQgghhJByJF/bBOMWw8A7ZgqIB+0JmiSA98gaQv1ixDRYjwsOtfeD93NRV8cB3Sy8cSj+b2QWf7zclYSQL0tpSSn+nLobr8KiMePQtFrL8Uo+DVkpWVgzeB1MbIwx8Y8J5c5W+ZIJCgX4bcJ2ZCZnYur+yeyI/o9BhVGBn8N4BIavhneDaVBTqVru4NTUVGhqakJPr3ZeO43cJYQQQgghhBBCCPnEqaqpopNvB+Sm5yL8aji3mJAvUtS9V4gIfYHu47t91MDu54SCu4QQQgghhBBCCCGfAccWjmj/v3a4tv86stPkF1Yj5EsiKBTgyt6rqNfGBU27NuEWEzEK7hJCCCGk2qIychGVkcvdTQghhBBCahDDMOgxoTtUVFRwbd81pblQCfkSPLr8GDGPYjFwVv//bFqKiqDgLiGEEEKqZWHwUzj/fhbOv5/FwuCn3GJCCCGEEFKD9Iz1sPDsfPSd2qdGFqsinyYDMwOsDFqGaQem/GcDmy16Ncfa26th28CWW0SkUHCXEEIIIdWy9vYLhbcJIYQQQgghhNQuCu4SQgghpFrs9LUV3iaEEEIIIYQQUrsouEsIIYSQavmjRwt0sTdHF3tz/NGjBbeYEEIIIYQQQkgtoeAuIYQQQqqlrY0JLg5tj4tD26OtjQm3mBBCCCGEEEJILaHgLiGEEEIIIYQQQgghhHyGKLhLCCGEEEIIIYQQQgghnyEK7hJCCCGk2l6k5+BFeg53NyGEEEIIIYSQWkTBXUIIIYRUy7ygJ6i/4xzq7ziHeUFPuMWEEEIIIYQQQmoJBXcJIYQQUi0/33mp8DYhhBBCCCGEkNpFwV1CCCGEVIsDX1fhbULIfwDDcPcQQgghhJCPiIK7hBBCCKmWP3o0R/e65uhe1xx/9GjOLSaEEEIIIYQQUksouEsIIYSQavG0NsG5Ie1xbkh7eFqbcIsJIf8FJdwdAITcHaRMWMbdRQghhJBPlFD4eXyZYYSfyyslhBBCCCGEfDKEpcXIenkdmoYuwGtzXDQ+h56FfVCalIwQyzAYlBiCMczFq7JkGOrZ4F1hKprwm3Ef5j+hpKwELvr18DjzIXTV9LjFVZKZmQkdHR3weDxuESGEEEJqQGFZIVx0nZGQ+QSW+vWgo27IrVIhqamp0NTUhJ5ezXwH4KLgLiGEEEIIIaTqhGUAo4Id0dvwbV0/qDKquPz2Iqy0rKHP00dg4in0suqDy28vYpT9t9x7kypKSkoCn8+HlpYWt4gQQgghn5DaDu5SWgZCCCGEVNvTtCw8Tcvi7iaE/Bcw9JOCEEIIIeTfQt/ECCGEEFItc64/QaOdF9Bo5wXM
uf6EW0wIIYQQQgghpJZQcJcQQggh1fLL3UiFtwkhhBBCCCGE1C4K7hJCCCGkWhwNdRTeJoQQQgghhBBSuyi4SwghhJBq+bNHS/RytEQvR0v82aMlt5gQQgghhBBCSC35P/czI2FqKrf0AAAAAElFTkSuQmCC" | |
| } | |
| }, | |
| "cell_type": "markdown", | |
| "id": "d89d4817-0015-4569-b7be-f23610d38f84", | |
| "metadata": {}, | |
| "source": [ | |
| "" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "e125bce8-eaac-4979-bbe5-ba1b89b6ab93", | |
| "metadata": {}, | |
| "source": [ | |
| "### `ppermute` Loop Version\n", | |
| "Use a `ppermute` ring shift to overlap communication with computation." |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 34, | |
| "id": "9231dffd-5a10-4c4e-b8c0-c08e906c0a5f", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "7.17 ms ± 202 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.jit\n", | |
| "@jax.shard_map(\n", | |
| " mesh=global_mesh, in_specs=(P(\"dp\", None), P(\"dp\", None)), out_specs=P(\"dp\", None)\n", | |
| ")\n", | |
| "def matmul_fsdp_allgather_overlapped(\n", | |
| " x_block: Float[Array, \"batch_local n_in\"],\n", | |
| " w_block: Float[Array, \"n_in_local n_out\"],\n", | |
| ") -> Float[Array, \"batch_local n_out\"]:\n", | |
| " # f32[1024@dp, 4096] @ f32[4096@dp, 256]\n", | |
| " # local blocks: f32[256{dp=4}, 4096] @ f32[1024{dp=4}, 256] -> f32[256, 256]\n", |
| " num_devices = jax.lax.axis_size(\"dp\")\n", | |
| " device_idx = jax.lax.axis_index(\"dp\")\n", | |
| "\n", | |
| " x_block_v_slice_size = x_block.shape[1] // num_devices\n", | |
| "\n", | |
| " def _get_x_block_v_slice(i: int) -> Float[Array, \"batch_local n_in_local\"]:\n", | |
| " return jax.lax.dynamic_slice_in_dim(\n", | |
| " x_block,\n", | |
| " start_index=i * x_block_v_slice_size,\n", | |
| " slice_size=x_block_v_slice_size,\n", | |
| " axis=1,\n", | |
| " )\n", | |
| "\n", | |
| " # in a forward ring shift, the program receives the previous w_block in a roll\n", | |
| " # we vertically slice the corresponding chunk of x_block for matrix multiplication in col @ row format\n", | |
| " # and reduce the results on the fly\n", | |
| "\n", | |
| " # start from the current col chunk @ current w row: f32[batch_local, n_out]\n", | |
| " out_block = _get_x_block_v_slice(device_idx) @ w_block\n", | |
| " # ring shift forward for num_devices - 1 steps\n", | |
| " for i in range(1, num_devices):\n", | |
| " w_block = ring_shift(w_block, axis_name=\"dp\", offset=1)\n", | |
| " out_block += _get_x_block_v_slice((device_idx - i) % num_devices) @ w_block\n", | |
| " return out_block\n", | |
| "\n", | |
| "\n", | |
| "assert jnp.allclose(matmul_fsdp_allgather_overlapped(x, w), y, atol=1e-4)\n", | |
| "%timeit matmul_fsdp_allgather_overlapped(x, w).block_until_ready() # fsdp is faster than the non-parallel version" |
| ] | |
| }, | |
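| { |
| "cell_type": "markdown", |
| "id": "a3f1c7d2-5e84-4b9a-8c61-2d7e9f0b4a13", |
| "metadata": {}, |
| "source": [ |
| "The index bookkeeping in the loop above can be sanity-checked without any accelerators. The cell below is a hypothetical single-process NumPy simulation, not a JAX API: a Python list stands in for the `dp` ring, and `sim_ring_shift` assumes that `ring_shift(..., offset=1)` delivers every block to the next device, so device `d` receives the previous device's `w_block` (as the comments above describe). After `i` shifts, device `d` holds the rows of `w` from device `(d - i) % num_devices` and multiplies them by the matching column slice of its own `x_block`." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "a3f1c7d2-5e84-4b9a-8c61-2d7e9f0b4a14", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "import numpy as np\n", |
| "\n", |
| "\n", |
| "def sim_ring_shift(blocks, offset):\n", |
| "    \"\"\"Stand-in for ring_shift: device d receives the block held by device (d - offset) % n.\"\"\"\n", |
| "    n = len(blocks)\n", |
| "    return [blocks[(d - offset) % n] for d in range(n)]\n", |
| "\n", |
| "\n", |
| "def simulate_allgather_matmul(num_devices=4, batch=8, n_in=16, n_out=3, seed=0):\n", |
| "    \"\"\"Hypothetical single-process simulation: list index d plays device d on the dp ring.\"\"\"\n", |
| "    rng = np.random.default_rng(seed)\n", |
| "    x_full = rng.standard_normal((batch, n_in))\n", |
| "    w_full = rng.standard_normal((n_in, n_out))\n", |
| "    x_blocks = np.split(x_full, num_devices, axis=0)  # like P(\"dp\", None)\n", |
| "    w_blocks = np.split(w_full, num_devices, axis=0)  # like P(\"dp\", None)\n", |
| "    k = n_in // num_devices  # width of one vertical slice of x_block\n", |
| "\n", |
| "    def x_slice(d, i):\n", |
| "        # columns of device d's x_block that multiply the rows held in w_blocks[i]\n", |
| "        return x_blocks[d][:, i * k : (i + 1) * k]\n", |
| "\n", |
| "    out = [x_slice(d, d) @ w_blocks[d] for d in range(num_devices)]\n", |
| "    held = list(w_blocks)\n", |
| "    for i in range(1, num_devices):\n", |
| "        held = sim_ring_shift(held, offset=1)  # held[d] is now w_blocks[(d - i) % num_devices]\n", |
| "        out = [out[d] + x_slice(d, (d - i) % num_devices) @ held[d] for d in range(num_devices)]\n", |
| "    return np.concatenate(out, axis=0), x_full @ w_full\n", |
| "\n", |
| "\n", |
| "y_sim, y_ref = simulate_allgather_matmul()\n", |
| "print(np.allclose(y_sim, y_ref))" |
| ] |
| }, |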
| { | |
| "cell_type": "markdown", | |
| "id": "b5c9b6f0-0c34-436e-8356-b109f10fd893", | |
| "metadata": {}, | |
| "source": [ | |
| "The implementation above overlaps communication with computation and never gathers the full weight `w` of shape `[n_in, n_out]` onto any single device. However, on TPU it uses only **half of the interconnect bandwidth**, because it permutes in only one direction around the ring. To permute bidirectionally, split each block in half and send one half in each direction:" |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 35, | |
| "id": "77ee42a6-cb84-4411-83cf-719f6065d283", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "7.37 ms ± 132 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.jit\n", | |
| "@jax.shard_map(\n", | |
| " mesh=global_mesh, in_specs=(P(\"dp\", None), P(\"dp\", None)), out_specs=P(\"dp\", None)\n", | |
| ")\n", | |
| "def matmul_fsdp_allgather_overlapped(\n", | |
| " x_block: Float[Array, \"batch_local n_in\"],\n", | |
| " w_block: Float[Array, \"n_in_local n_out\"],\n", | |
| ") -> Float[Array, \"batch_local n_out\"]:\n", | |
| " # f32[1024@dp, 4096] @ f32[4096@dp, 256]\n", | |
| " # local blocks: f32[256{dp=4}, 4096] @ f32[1024{dp=4}, 256] -> f32[256, 256]\n", |
| " num_devices = jax.lax.axis_size(\"dp\")\n", | |
| " device_idx = jax.lax.axis_index(\"dp\")\n", | |
| "\n", | |
| " x_block_v_slice_size, _remainder = divmod(x_block.shape[1], num_devices * 2)\n", | |
| " assert _remainder == 0, f\"{x_block.shape[1]=} must be divisible by {2*num_devices=}\"\n", | |
| "\n", | |
| " def _get_x_block_v_slice(\n", | |
| " i: int, is_right_half: bool = False\n", | |
| " ) -> Float[Array, \"batch_local n_in_local//2\"]:\n", | |
| " \"\"\"slice the local x_block vertically into num_devices chunks, and split each chunk into left and right halves\"\"\"\n", | |
| " return jax.lax.dynamic_slice_in_dim(\n", | |
| " x_block,\n", | |
| " start_index=(2 * i + is_right_half) * x_block_v_slice_size,\n", | |
| " slice_size=x_block_v_slice_size,\n", | |
| " axis=1,\n", | |
| " )\n", | |
| "\n", | |
| " # Initialization: split w_block horizontally into two halves (the rows) and\n", | |
| " # get the corresponding vertical chunks of the local x_block (the cols) to run col @ row matrix multiplication\n", | |
| " # and reduce on the fly\n", | |
| " w_block_up, w_block_bottom = jnp.split(w_block, 2, axis=0)\n", | |
| " x_block_chunk_left = _get_x_block_v_slice(device_idx, is_right_half=False)\n", | |
| " x_block_chunk_right = _get_x_block_v_slice(device_idx, is_right_half=True)\n", | |
| " out_block = x_block_chunk_left @ w_block_up\n", | |
| " out_block += x_block_chunk_right @ w_block_bottom # f32[batch_local, n_out]\n", | |
| "\n", | |
| " # bidirectional ring shift w_block for num_devices - 1 steps to finish the whole x_block @ w_block\n", | |
| " for i in range(1, num_devices):\n", | |
| " # ring shift forward and receive w_block_up from prev device\n", | |
| " w_block_up = ring_shift(w_block_up, axis_name=\"dp\", offset=1)\n", | |
| " # ring shift backward and receive w_block_bottom from next device\n", | |
| " w_block_bottom = ring_shift(w_block_bottom, axis_name=\"dp\", offset=-1)\n", | |
| " x_block_chunk_left = _get_x_block_v_slice(\n", | |
| " (device_idx - i) % num_devices, is_right_half=False\n", | |
| " )\n", | |
| " x_block_chunk_right = _get_x_block_v_slice(\n", | |
| " (device_idx + i) % num_devices, is_right_half=True\n", | |
| " )\n", | |
| " out_block += x_block_chunk_left @ w_block_up\n", | |
| " out_block += x_block_chunk_right @ w_block_bottom\n", | |
| " return out_block\n", | |
| "\n", | |
| "\n", | |
| "assert jnp.allclose(matmul_fsdp_allgather_overlapped(x, w), y, atol=1e-4)\n", | |
| "%timeit matmul_fsdp_allgather_overlapped(x, w).block_until_ready() # fsdp is faster than the non-parallel version" |
| ] | |
| }, | |
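| { |
| "cell_type": "markdown", |
| "id": "b7e2d4c9-1a36-4f85-9b02-6c8d3e5f7a21", |
| "metadata": {}, |
| "source": [ |
| "The same kind of simulation can check the bidirectional schedule. In this hypothetical NumPy sketch, `sim_ring_shift` with `offset=1` again makes device `d` receive from device `d - 1`, and `offset=-1` makes it receive from device `d + 1`. The top halves of the weight blocks therefore travel forward while the bottom halves travel backward: at step `i`, device `d` pairs the upper half from device `(d - i) % num_devices` with the left half of column chunk `(d - i)`, and the lower half from device `(d + i) % num_devices` with the right half of chunk `(d + i)`." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "b7e2d4c9-1a36-4f85-9b02-6c8d3e5f7a22", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "import numpy as np\n", |
| "\n", |
| "\n", |
| "def sim_ring_shift(blocks, offset):\n", |
| "    \"\"\"Stand-in for ring_shift: device d receives the block held by device (d - offset) % n.\"\"\"\n", |
| "    n = len(blocks)\n", |
| "    return [blocks[(d - offset) % n] for d in range(n)]\n", |
| "\n", |
| "\n", |
| "def simulate_bidirectional_matmul(num_devices=4, batch=8, n_in=16, n_out=3, seed=1):\n", |
| "    \"\"\"Hypothetical single-process simulation of the bidirectional ring; list index d plays device d.\"\"\"\n", |
| "    rng = np.random.default_rng(seed)\n", |
| "    x_full = rng.standard_normal((batch, n_in))\n", |
| "    w_full = rng.standard_normal((n_in, n_out))\n", |
| "    x_blocks = np.split(x_full, num_devices, axis=0)\n", |
| "    w_blocks = np.split(w_full, num_devices, axis=0)\n", |
| "    h = n_in // (2 * num_devices)  # width of one half slice\n", |
| "\n", |
| "    def x_half(d, i, right):\n", |
| "        # left (right=0) or right (right=1) half of the i-th vertical chunk of device d's x_block\n", |
| "        start = (2 * i + right) * h\n", |
| "        return x_blocks[d][:, start : start + h]\n", |
| "\n", |
| "    up = [np.split(b, 2, axis=0)[0] for b in w_blocks]\n", |
| "    bottom = [np.split(b, 2, axis=0)[1] for b in w_blocks]\n", |
| "    out = [x_half(d, d, 0) @ up[d] + x_half(d, d, 1) @ bottom[d] for d in range(num_devices)]\n", |
| "    for i in range(1, num_devices):\n", |
| "        up = sim_ring_shift(up, offset=1)  # up[d] now came from device (d - i) % num_devices\n", |
| "        bottom = sim_ring_shift(bottom, offset=-1)  # bottom[d] now came from (d + i) % num_devices\n", |
| "        out = [\n", |
| "            out[d]\n", |
| "            + x_half(d, (d - i) % num_devices, 0) @ up[d]\n", |
| "            + x_half(d, (d + i) % num_devices, 1) @ bottom[d]\n", |
| "            for d in range(num_devices)\n", |
| "        ]\n", |
| "    return np.concatenate(out, axis=0), x_full @ w_full\n", |
| "\n", |
| "\n", |
| "y_sim, y_ref = simulate_bidirectional_matmul()\n", |
| "print(np.allclose(y_sim, y_ref))" |
| ] |
| }, |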
| { | |
| "cell_type": "markdown", | |
| "id": "0661e107-65e2-4a2c-8b0a-40806684b1fa", | |
| "metadata": {}, | |
| "source": [ | |
| "Profile the bidirectional ring-shift version to verify that the all-gather no longer appears in the trace." |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 40, | |
| "id": "08ed557c-643b-445b-9919-8cbfa30bb24a", | |
| "metadata": { | |
| "collapsed": true, | |
| "jupyter": { | |
| "outputs_hidden": true | |
| }, | |
| "scrolled": true | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "E0130 16:19:40.255039 3605 python_hooks.cc:416] Can't import tensorflow.python.profiler.trace\n", | |
| "E0130 16:19:40.271309 3605 python_hooks.cc:416] Can't import tensorflow.python.profiler.trace\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/html": [ | |
| "\n", | |
| " <iframe id=\"tensorboard-frame-d812e4d4352a8fe5\" width=\"100%\" height=\"800\" frameborder=\"0\">\n", | |
| " </iframe>\n", | |
| " <script>\n", | |
| " (function() {\n", | |
| " const frame = document.getElementById(\"tensorboard-frame-d812e4d4352a8fe5\");\n", | |
| " const url = new URL(\"/\", window.location);\n", | |
| " const port = 6007;\n", | |
| " if (port) {\n", | |
| " url.port = port;\n", | |
| " }\n", | |
| " frame.src = url;\n", | |
| " })();\n", | |
| " </script>\n", | |
| " " | |
| ], | |
| "text/plain": [ | |
| "<IPython.core.display.HTML object>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "with jax.profiler.trace(log_dir=\"./jax_profiler\"):\n", | |
| " _ = matmul_fsdp_allgather_overlapped(x, w).block_until_ready()\n", | |
| " # ^ block_until_ready() is needed inside the profiler context so the asynchronously dispatched computation is captured\n", |
| "\n", | |
| "%tensorboard --logdir=\"./jax_profiler\"" | |
| ] | |
| }, | |
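| { |
| "cell_type": "markdown", |
| "id": "c5d9a1e3-7f42-4b68-8d13-9e0f2a4b6c31", |
| "metadata": {}, |
| "source": [ |
| "As a lighter-weight alternative to reading the trace, one can search the compiled HLO text for collective op names. The cell below is a sketch meant to run as a fresh standalone script, under stated assumptions: it fakes 4 CPU devices with `XLA_FLAGS=--xla_force_host_platform_device_count=4` (this only works if set before `jax` is imported), and `demo_ring_shift` and `demo_matmul_ring` are hypothetical re-implementations of the unidirectional ring matmul, not the notebook's own helpers. If the ring schedule has replaced the all-gather, the compiled module should contain `collective-permute` ops and no `all-gather`." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "c5d9a1e3-7f42-4b68-8d13-9e0f2a4b6c32", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "import os\n", |
| "\n", |
| "# Assumption: run as a fresh standalone script. These variables only take effect if set\n", |
| "# before jax is imported; inside this notebook the already-initialized devices are used.\n", |
| "os.environ.setdefault(\"JAX_PLATFORMS\", \"cpu\")\n", |
| "os.environ[\"XLA_FLAGS\"] = (\n", |
| "    os.environ.get(\"XLA_FLAGS\", \"\") + \" --xla_force_host_platform_device_count=4\"\n", |
| ")\n", |
| "\n", |
| "from functools import partial\n", |
| "\n", |
| "import jax\n", |
| "import jax.numpy as jnp\n", |
| "import numpy as np\n", |
| "from jax.sharding import Mesh, PartitionSpec as P\n", |
| "\n", |
| "try:\n", |
| "    from jax import shard_map  # newer JAX, as used in this notebook\n", |
| "except ImportError:\n", |
| "    from jax.experimental.shard_map import shard_map  # fallback for older JAX\n", |
| "\n", |
| "demo_mesh = Mesh(np.array(jax.devices()[:4]), (\"dp\",))\n", |
| "n_dev = demo_mesh.shape[\"dp\"]\n", |
| "\n", |
| "\n", |
| "def demo_ring_shift(a, offset):\n", |
| "    \"\"\"Hypothetical helper: device i sends `a` to device (i + offset) % n_dev.\"\"\"\n", |
| "    perm = [(i, (i + offset) % n_dev) for i in range(n_dev)]\n", |
| "    return jax.lax.ppermute(a, \"dp\", perm=perm)\n", |
| "\n", |
| "\n", |
| "@jax.jit\n", |
| "@partial(\n", |
| "    shard_map, mesh=demo_mesh, in_specs=(P(\"dp\", None), P(\"dp\", None)), out_specs=P(\"dp\", None)\n", |
| ")\n", |
| "def demo_matmul_ring(x_block, w_block):\n", |
| "    k = x_block.shape[1] // n_dev\n", |
| "    d = jax.lax.axis_index(\"dp\")\n", |
| "\n", |
| "    def x_slice(i):\n", |
| "        # vertical slice of the local x_block that matches the rows of w held from device i\n", |
| "        return jax.lax.dynamic_slice_in_dim(x_block, i * k, k, axis=1)\n", |
| "\n", |
| "    out = x_slice(d) @ w_block\n", |
| "    for i in range(1, n_dev):\n", |
| "        w_block = demo_ring_shift(w_block, offset=1)  # receive the previous device's block\n", |
| "        out += x_slice((d - i) % n_dev) @ w_block\n", |
| "    return out\n", |
| "\n", |
| "\n", |
| "x_demo, w_demo = jnp.ones((16, 32)), jnp.ones((32, 8))\n", |
| "hlo_text = demo_matmul_ring.lower(x_demo, w_demo).compile().as_text()\n", |
| "print(\"all-gather\" in hlo_text, \"collective-permute\" in hlo_text)" |
| ] |
| }, |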
| { | |
| "cell_type": "markdown", | |
| "id": "88f484f3-9189-45e7-916d-daea6233f07a", | |
| "metadata": {}, | |
| "source": [ | |
| "### Matmul Reduce Scatter\n", | |
| "Use Case: TP forward step; FSDP backward step" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 36, | |
| "id": "ac2a8cc9-2180-4579-9993-8d904c077d66", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "jax.typeof(y)=ShapedArray(float32[1024@(dp,tp),256])\n", | |
| "True\n", | |
| "2.85 ms ± 101 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.jit\n", | |
| "@jax.shard_map(\n", | |
| " mesh=global_mesh,\n", | |
| " in_specs=(P(None, (\"dp\", \"tp\")), P((\"dp\", \"tp\"), None)),\n", | |
| " out_specs=P((\"dp\", \"tp\"), None),\n", | |
| ")\n", | |
| "def matmul_psumscatter(\n", | |
| " x_block: Float[Array, \"batch n_in_block\"],\n", | |
| " w_block: Float[Array, \"n_in_block n_out\"],\n", | |
| ") -> Float[Array, \"batch_scatter n_out\"]:\n", | |
| " # f32[1024, 4096@(dp,tp)] @ f32[4096@(dp,tp), 256]\n", | |
| " # f32[1024, 512{dp=4,tp=2}] @ f32[512{dp=4,tp=2}, 256] -> f32[1024, 256]{dp,tp}\n", | |
| " # -> psum -> f32[1024,256] -> scatter -> f32[1024@(dp,tp), 256]\n", | |
| " out_block = x_block @ w_block\n", | |
| " out_block_reduce_scattered = jax.lax.psum_scatter(\n", |
| " out_block, axis_name=(\"dp\", \"tp\"), scatter_dimension=0, tiled=True\n", |
| " )\n", |
| " return out_block_reduce_scattered\n", |
| "\n", | |
| "\n", | |
| "y = matmul_psumscatter(x, w)\n", | |
| "print(f\"{jax.typeof(y)=}\")\n", | |
| "print(jnp.allclose(y, x @ w, atol=1e-4))\n", | |
| "%timeit _ = matmul_psumscatter(x, w).block_until_ready()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "38629999-d23c-4c5b-9d0a-828a6cbd434a", | |
| "metadata": {}, | |
| "source": [ | |
| "However, the scatter communication cannot start until the entire local matrix multiplication has finished. To overlap communication with computation, we can instead implement `psum_scatter` as `N - 1` `ppermute` ring shifts and interleave each shift with the local matrix multiplication for the next row block." |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 37, | |
| "id": "95500c5d-a7cf-4ba9-8dc0-e5de76692252", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "jax.typeof(y)=ShapedArray(float32[1024@(dp,tp),256])\n", | |
| "True\n", | |
| "4.72 ms ± 104 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.jit\n", | |
| "@jax.shard_map(\n", | |
| " mesh=global_mesh,\n", | |
| " in_specs=(P(None, (\"dp\", \"tp\")), P((\"dp\", \"tp\"), None)),\n", | |
| " out_specs=P((\"dp\", \"tp\"), None),\n", | |
| ")\n", | |
| "def matmul_psumscatter_overlapped(\n", | |
| " x_block: Float[Array, \"batch n_in_block\"],\n", | |
| " w_block: Float[Array, \"n_in_block n_out\"],\n", | |
| ") -> Float[Array, \"batch_scatter n_out\"]:\n", | |
| " # f32[1024, 4096@(dp,tp)] @ f32[4096@(dp,tp), 256]\n", | |
| " # f32[1024, 512{dp=4,tp=2}] @ f32[512{dp=4,tp=2}, 256] -> f32[1024, 256]{dp,tp}\n", | |
| " # -> psum -> f32[1024,256] -> scatter -> f32[1024@(dp,tp), 256]\n", | |
| " num_devices = jax.lax.axis_size((\"dp\", \"tp\"))\n", | |
| " device_idx = jax.lax.axis_index((\"dp\", \"tp\"))\n", | |
| " x_block_h_slice_size, _remainder = divmod(x_block.shape[0], num_devices)\n", | |
| " assert _remainder == 0\n", | |
| "\n", | |
| " def _get_x_block_h_slice(i: int):\n", | |
| " \"\"\"split the local x_block into dp * tp row blocks and return the i-th block\"\"\"\n", |
| " return jax.lax.dynamic_slice_in_dim(\n", | |
| " x_block,\n", | |
| " start_index=i * x_block_h_slice_size,\n", | |
| " slice_size=x_block_h_slice_size,\n", | |
| " axis=0,\n", | |
| " )\n", | |
| "\n", | |
| " # each device starts with its partial product for device (device_idx + 1) % num_devices,\n", |
| " # i.e. the row block that must travel num_devices - 1 hops when shifting backward\n", |
| " y_block_scattered = (\n", |
| " _get_x_block_h_slice((device_idx - (num_devices - 1)) % num_devices) @ w_block\n", |
| " )\n", |
| " # ring shift the running sum backward and add the local matmul at each hop, so the fully reduced block ends on device_idx\n", |
| " for i in range(num_devices - 2, -1, -1):\n", |
| " y_block_scattered = ring_shift(\n", |
| " y_block_scattered, axis_name=(\"dp\", \"tp\"), offset=-1\n", |
| " )\n", |
| " y_block_scattered += (\n", |
| " _get_x_block_h_slice((device_idx - i) % num_devices) @ w_block\n", |
| " )\n", |
| " return y_block_scattered\n", |
| "\n", | |
| "\n", | |
| "y = matmul_psumscatter_overlapped(x, w)\n", | |
| "print(f\"{jax.typeof(y)=}\")\n", | |
| "print(jnp.allclose(y, x @ w, atol=1e-4))\n", | |
| "%timeit _ = matmul_psumscatter_overlapped(x, w).block_until_ready()" | |
| ] | |
| }, | |
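| { |
| "cell_type": "markdown", |
| "id": "d1b8e6f4-3c27-4a95-8e64-5f1a7c9d2b41", |
| "metadata": {}, |
| "source": [ |
| "The index arithmetic of the ring reduce-scatter can also be sanity-checked in NumPy. This is a hypothetical single-process sketch: the list index stands for a device on the flattened (`dp`, `tp`) ring, and `sim_ring_shift` assumes that `offset=-1` makes device `d` receive the running sum held by device `d + 1`. Each device starts with its partial product for row block `(d + 1) % num_devices`; after `num_devices - 1` backward shifts, every device holds its own fully reduced row block." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "d1b8e6f4-3c27-4a95-8e64-5f1a7c9d2b42", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "import numpy as np\n", |
| "\n", |
| "\n", |
| "def sim_ring_shift(blocks, offset):\n", |
| "    \"\"\"Stand-in for ring_shift: device d receives the block held by device (d - offset) % n.\"\"\"\n", |
| "    n = len(blocks)\n", |
| "    return [blocks[(d - offset) % n] for d in range(n)]\n", |
| "\n", |
| "\n", |
| "def simulate_matmul_reduce_scatter(num_devices=4, batch=8, n_in=16, n_out=3, seed=2):\n", |
| "    \"\"\"Hypothetical single-process simulation: list index d plays device d on the flattened (dp, tp) ring.\"\"\"\n", |
| "    rng = np.random.default_rng(seed)\n", |
| "    x_full = rng.standard_normal((batch, n_in))\n", |
| "    w_full = rng.standard_normal((n_in, n_out))\n", |
| "    x_blocks = np.split(x_full, num_devices, axis=1)  # like P(None, (\"dp\", \"tp\"))\n", |
| "    w_blocks = np.split(w_full, num_devices, axis=0)  # like P((\"dp\", \"tp\"), None)\n", |
| "    r = batch // num_devices  # rows per output shard\n", |
| "\n", |
| "    def partial_rows(d, i):\n", |
| "        # device d's local contribution to output row block i\n", |
| "        return x_blocks[d][i * r : (i + 1) * r] @ w_blocks[d]\n", |
| "\n", |
| "    # device d starts with the partial sum destined for device (d + 1) % num_devices\n", |
| "    acc = [partial_rows(d, (d + 1) % num_devices) for d in range(num_devices)]\n", |
| "    for i in range(num_devices - 2, -1, -1):\n", |
| "        acc = sim_ring_shift(acc, offset=-1)  # receive the running sum from device d + 1\n", |
| "        acc = [acc[d] + partial_rows(d, (d - i) % num_devices) for d in range(num_devices)]\n", |
| "    # after num_devices - 1 backward shifts, acc[d] is the fully reduced row block d\n", |
| "    return np.concatenate(acc, axis=0), x_full @ w_full\n", |
| "\n", |
| "\n", |
| "y_sim, y_ref = simulate_matmul_reduce_scatter()\n", |
| "print(np.allclose(y_sim, y_ref))" |
| ] |
| }, |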
| { | |
| "cell_type": "markdown", | |
| "id": "b52a612c-b4fb-49b4-8f97-400cb7ad3d49", | |
| "metadata": {}, | |
| "source": [ | |
| "**Bidirectional ring shift**" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 38, | |
| "id": "2e080710-32bf-437d-95b0-c1343af2f1a4", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "jax.typeof(y)=ShapedArray(float32[1024@(dp,tp),256])\n", | |
| "True\n", | |
| "9.17 ms ± 143 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.jit\n", | |
| "@jax.shard_map(\n", |
| " mesh=global_mesh,\n", |
| " in_specs=(P(None, (\"dp\", \"tp\")), P((\"dp\", \"tp\"), None)),\n", | |
| " out_specs=P((\"dp\", \"tp\"), None),\n", | |
| ")\n", | |
| "def matmul_psumscatter_overlapped(\n", | |
| " x_block: Float[Array, \"batch n_in_block\"],\n", | |
| " w_block: Float[Array, \"n_in_block n_out\"],\n", | |
| ") -> Float[Array, \"batch_scatter n_out\"]:\n", | |
| " # f32[1024, 4096@(dp,tp)] @ f32[4096@(dp,tp), 256]\n", | |
| " # f32[1024, 512{dp=4,tp=2}] @ f32[512{dp=4,tp=2}, 256] -> f32[1024, 256]{dp,tp}\n", | |
| " # -> psum -> f32[1024,256] -> scatter -> f32[1024@(dp,tp), 256]\n", | |
| " num_devices = jax.lax.axis_size((\"dp\", \"tp\"))\n", | |
| " device_idx = jax.lax.axis_index((\"dp\", \"tp\"))\n", | |
| " # horizontally slice x_block into num_devices row blocks and split each block into top and bottom halves\n", |
| " x_block_h_slice_size, _remainder = divmod(x_block.shape[0], num_devices * 2)\n", |
| " assert _remainder == 0, f\"{x_block.shape[0]=} must be divisible by {num_devices * 2=}\"\n", |
| "\n", | |
| " def _get_x_block_h_slice(i: int, is_bottom_half: bool = False):\n", | |
| " return jax.lax.dynamic_slice_in_dim(\n", | |
| " x_block,\n", | |
| " start_index=(2 * i + is_bottom_half) * x_block_h_slice_size,\n", | |
| " slice_size=x_block_h_slice_size,\n", | |
| " axis=0,\n", | |
| " )\n", | |
| "\n", | |
| " # Backward ring (top halves): each device starts with its partial product for device (device_idx + 1) % num_devices,\n", |
| " # shifts the running sum backward, and adds its local matmul at each hop so the reduced top half ends on device_idx\n", |
| " x_block_h_slice_top = _get_x_block_h_slice(\n", |
| " (device_idx - (num_devices - 1)) % num_devices, is_bottom_half=False\n", |
| " )\n", |
| " y_block_scattered_top = x_block_h_slice_top @ w_block\n", |
| " # Forward ring (bottom halves): each device starts with its partial product for device (device_idx - 1) % num_devices,\n", |
| " # shifts the running sum forward, and adds its local matmul at each hop so the reduced bottom half ends on device_idx\n", |
| " x_block_h_slice_bottom = _get_x_block_h_slice(\n", |
| " (device_idx - 1) % num_devices, is_bottom_half=True\n", |
| " )\n", |
| " y_block_scattered_bottom = x_block_h_slice_bottom @ w_block\n", |
| "\n", |
| " for i, j in zip(range(num_devices - 2, -1, -1), range(2, num_devices + 1)):\n", |
| " y_block_scattered_top = ring_shift(\n", |
| " y_block_scattered_top, axis_name=(\"dp\", \"tp\"), offset=-1\n", |
| " )\n", |
| " y_block_scattered_bottom = ring_shift(\n", |
| " y_block_scattered_bottom, axis_name=(\"dp\", \"tp\"), offset=1\n", |
| " )\n", |
| " x_block_h_slice_top = _get_x_block_h_slice(\n", |
| " (device_idx - i) % num_devices, is_bottom_half=False\n", |
| " )\n", |
| " y_block_scattered_top += x_block_h_slice_top @ w_block\n", |
| " x_block_h_slice_bottom = _get_x_block_h_slice(\n", |
| " (device_idx - j) % num_devices, is_bottom_half=True\n", |
| " )\n", |
| " y_block_scattered_bottom += x_block_h_slice_bottom @ w_block\n", |
| " return jnp.concat((y_block_scattered_top, y_block_scattered_bottom), axis=0)\n", |
| "\n", | |
| "\n", | |
| "y = matmul_psumscatter_overlapped(x, w)\n", | |
| "print(f\"{jax.typeof(y)=}\")\n", | |
| "print(jnp.allclose(y, x @ w, atol=1e-4))\n", | |
| "%timeit _ = matmul_psumscatter_overlapped(x, w).block_until_ready()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "f7967a90-aa18-4e97-829e-082337c90e81", | |
| "metadata": {}, | |
| "source": [ | |
| "## Practical Examples" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "b082f6a8-970c-4651-b7f1-fedbb9c5189d", | |
| "metadata": {}, | |
| "source": [ | |
| "### Distributed Training of Neural Networks\n", | |
| "\n", | |
| "Set up a basic neural network example." |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 39, | |
| "id": "44a4bba7-55c5-4ecc-adcb-867cbb5571a9", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from dataclasses import dataclass, field\n", | |
| "from typing import Any, Callable\n", | |
| "\n", | |
| "\n", | |
| "# Register dataclass as PyTree\n", | |
| "@jax.tree_util.register_dataclass\n", | |
| "@dataclass\n", | |
| "class Linear[T: (jax.Array, P)]:\n", | |
| " w: T\n", | |
| " b: T\n", | |
| "\n", | |
| "\n", | |
| "@jax.tree_util.register_dataclass\n", | |
| "@dataclass\n", | |
| "class Batch[T: (jax.Array, P)]:\n", | |
| " inputs: T\n", | |
| " targets: T\n", | |
| "\n", | |
| "\n", | |
| "# Type aliases\n", | |
| "type LinearParams = Linear[jax.Array]\n", | |
| "type LinearPSpec = Linear[P]\n", | |
| "type BatchData = Batch[jax.Array]\n", | |
| "type BatchPSpec = Batch[P]\n", | |
| "\n", | |
| "\n", | |
| "# @jax.tree_util.register_pytree_node_class\n", | |
| "@jax.tree_util.register_dataclass\n", | |
| "@dataclass\n", | |
| "class Model:\n", | |
| " \"\"\"JAX-compatible Model container -- Registered as a PyTree\"\"\"\n", | |
| "\n", | |
| " # Dynamic Fields (JAX Pytree Children)\n", | |
| " layers: list[LinearParams] | list[LinearPSpec]\n", | |
| "\n", | |
| " # Static Fields (JAX Aux Data)\n", | |
| " layer_sizes: tuple[int, ...] = field(metadata={\"static\": True})\n", | |
| "\n", | |
| " def __init__(\n", | |
| " self,\n", | |
| " layer_sizes: tuple[int, ...],\n", | |
| " key: jax.Array | None = None,\n", | |
| " layers: list[LinearParams] | list[LinearPSpec] | None = None,\n", | |
| " ):\n", | |
| " self.layer_sizes = layer_sizes\n", | |
| " # key is not a dataclass field, so it is not a pytree child and is not preserved by flatten/unflatten\n", |
| " self.key = key\n", | |
| "\n", | |
| " if layers is not None:\n", | |
| " self.layers = layers\n", | |
| " elif key is not None:\n", | |
| " self._init_weights()\n", | |
| " else:\n", | |
| " self.layers = []\n", | |
| "\n", | |
| " @property\n", | |
| " def num_layers(self) -> int:\n", | |
| " return len(self.layer_sizes) - 1\n", | |
| "\n", | |
| " def _init_weights(self) -> None:\n", | |
| " \"\"\"Random init self.layers as a list of LinearParams\"\"\"\n", | |
| " assert self.key is not None, \"Did you set key in model?\"\n", | |
| " keys = jax.random.split(self.key, self.num_layers)\n", | |
| " self.layers = []\n", | |
| " for k, n_in, n_out in zip(keys, self.layer_sizes[:-1], self.layer_sizes[1:]):\n", | |
| " k1, k2 = jax.random.split(k)\n", | |
| " w = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)\n", | |
| " b = jax.random.normal(k2, (n_out,))\n", | |
| " self.layers.append(Linear(w, b))\n", | |
| "\n", | |
| " def predict(\n", | |
| " self,\n", | |
| " inputs: Float[Array, \"b in\"],\n", | |
| " linear_fwd_fn: Callable[[jax.Array, LinearParams], jax.Array],\n", | |
| " ) -> Float[Array, \"b out\"]:\n", | |
| " x = inputs\n", | |
| " for layer in self.layers:\n", | |
| " x = linear_fwd_fn(x, layer)\n", | |
| " x = jax.nn.relu(x)\n", | |
| " return x\n", | |
| "\n", | |
| " def init_random_batch(self, key: jax.Array, batch_size: int = 32) -> BatchData:\n", | |
| " k1, k2 = jax.random.split(key)\n", | |
| " inputs = jax.random.normal(k1, (batch_size, self.layer_sizes[0]))\n", | |
| " targets = jax.random.normal(k2, (batch_size, self.layer_sizes[-1]))\n", | |
| " return Batch(inputs, targets)\n", | |
| "\n", | |
| " # if using @jax.tree_util.register_pytree_node_class instead of @jax.tree_util.register_dataclass\n", | |
| " # --- Reference Implementation of JAX Pytree Conversion ---\n", | |
| " # def tree_flatten(self) -> tuple[tuple[\"Dynamic\", ...], tuple[\"Static\", ...]]:\n", | |
| " # return (self.layers,), (self.layer_sizes,)\n", | |
| "\n", | |
| " # @classmethod\n", | |
| " # def tree_unflatten(\n", | |
| " # cls,\n", | |
| " # aux_data: tuple[\"Static\", ...],\n", | |
| " # children: tuple[\"Dynamic\", ...],\n", | |
| " # ) -> \"Model\":\n", | |
| " # (layers,) = children\n", | |
| " # (layer_sizes,) = aux_data\n", | |
| " # return cls(layer_sizes=layer_sizes, layers=layers)\n", | |
| "\n", | |
| "\n", | |
| "# Initialize model and data\n", | |
| "key_model, key_batch = jax.random.split(jax.random.key(0))\n", | |
| "model = Model(layer_sizes=(32, 128, 256, 128, 4), key=key_model)\n", | |
| "batch = model.init_random_batch(key=key_batch, batch_size=16)" | |
| ] | |
| }, | |
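| { |
| "cell_type": "markdown", |
| "id": "e9c3f5a7-2d18-4b76-9a35-8b4e6d0f1c51", |
| "metadata": {}, |
| "source": [ |
| "To see what `register_dataclass` does with the `static` field metadata, here is a minimal sketch on a hypothetical `TinyModel` that mirrors `Model`: the arrays inside `layers` become pytree leaves, while `layer_sizes` (marked `static`) is kept in the treedef as auxiliary data and is carried through `tree_map` unchanged. It assumes a JAX version in which `register_dataclass` infers data and static fields from the dataclass field metadata, which the `Model` definition above also relies on." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "e9c3f5a7-2d18-4b76-9a35-8b4e6d0f1c52", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "from dataclasses import dataclass, field\n", |
| "\n", |
| "import jax\n", |
| "import jax.numpy as jnp\n", |
| "\n", |
| "\n", |
| "@jax.tree_util.register_dataclass\n", |
| "@dataclass\n", |
| "class TinyModel:\n", |
| "    \"\"\"Hypothetical stand-in for Model: one dynamic field and one static field.\"\"\"\n", |
| "\n", |
| "    layers: list\n", |
| "    layer_sizes: tuple = field(metadata={\"static\": True})\n", |
| "\n", |
| "\n", |
| "tiny = TinyModel(layers=[jnp.ones((2, 3)), jnp.zeros((3,))], layer_sizes=(2, 3))\n", |
| "tiny_leaves, tiny_treedef = jax.tree_util.tree_flatten(tiny)\n", |
| "print([leaf.shape for leaf in tiny_leaves])  # only the arrays inside `layers` become leaves\n", |
| "\n", |
| "# tree_map rebuilds a TinyModel and carries the static layer_sizes through unchanged\n", |
| "doubled = jax.tree_util.tree_map(lambda a: 2 * a, tiny)\n", |
| "print(type(doubled).__name__, doubled.layer_sizes)" |
| ] |
| }, |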
| { | |
| "cell_type": "markdown", | |
| "id": "3489e9a9-9d93-4e7c-b644-bc2424fe27af", | |
| "metadata": {}, | |
| "source": [ | |
| "Baseline: single device, no parallelism." |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 40, | |
| "id": "79806266-8fb9-4d54-8f90-dfe3f37a29ba", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Baseline Loss: 4.0431547\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "def linear(x: jax.Array, layer: LinearParams) -> jax.Array:\n", | |
| " return x @ layer.w + layer.b\n", | |
| "\n", | |
| "\n", | |
| "# Compute the loss and its gradient w.r.t. the model; argnums=0 differentiates with respect to positional arg 0 (model).\n", |
| "@jax.jit\n", | |
| "@partial(jax.value_and_grad, argnums=0)\n", | |
| "def loss_baseline(model: Model, batch: BatchData) -> jax.Array:\n", | |
| " preds = model.predict(batch.inputs, linear_fwd_fn=linear)\n", | |
| " # MSE loss\n", | |
| " return jnp.mean(jnp.sum((preds - batch.targets) ** 2, axis=-1))\n", | |
| "\n", | |
| "\n", | |
| "loss_baseline_, grads_ = loss_baseline(model, batch)\n", | |
| "print(\"Baseline Loss:\", loss_baseline_)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 41, | |
| "id": "958ada43-0916-4fc7-984b-cd7c348d1636", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "459 μs ± 21.8 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%timeit jax.block_until_ready(loss_baseline(model, batch))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "b5f60638-69cd-49ab-ac03-89fd7ba82534", | |
| "metadata": {}, | |
| "source": [ | |
| "Note: we omit the weight update step for simplicity." | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "224d9f35-9939-4ccd-9336-fdbc3f40609e", | |
| "metadata": {}, | |
| "source": [ | |
| "#### Data Parallelism \n", | |
| "Shard the data along the batch dim over the `dp` mesh axis and replicate the model weights across it. The code below uses both mesh axes, `dp` and `tp`, as data-parallel axes (8-way DP).\n", | |
| "\n", | |
| "* Forward: to get the total (or mean) loss over the global batch, all-reduce the local losses with `psum` (or `pmean`) over the data-parallel mesh axes.\n", | |
| "* Backward: all-reduce (`psum`) the gradients over the same axes so that every replica applies the same update and the weight copies stay in sync." | |
| ] | |
| }, | |
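Why `pmean` reproduces the single-device loss: with equal-sized shards, the mean of the per-shard means is the global mean, and averaging the per-replica gradients gives the full-batch gradient. A minimal single-device sketch, assuming 8 data-parallel replicas simulated with `jax.vmap` over a reshaped batch; the toy `mse` loss and the shapes are illustrative, not the notebook's `Model`:

```python
import jax
import jax.numpy as jnp


def mse(w, x, y):
    # Squared error summed over outputs, averaged over the examples passed in
    return jnp.mean(jnp.sum((x @ w - y) ** 2, axis=-1))


kx, ky, kw = jax.random.split(jax.random.key(0), 3)
x = jax.random.normal(kx, (16, 32))
y = jax.random.normal(ky, (16, 4))
w = jax.random.normal(kw, (32, 4))

num_replicas = 8  # stands in for the total size of the data-parallel mesh axes
# Equal-sized batch shards: xs is [8, 2, 32], ys is [8, 2, 4]
xs = x.reshape(num_replicas, -1, x.shape[-1])
ys = y.reshape(num_replicas, -1, y.shape[-1])

# Each simulated replica holds the full w and sees only its own shard
local_loss, local_grad = jax.vmap(
    jax.value_and_grad(mse), in_axes=(None, 0, 0)
)(w, xs, ys)

# Averaging over the replica axis is what pmean over the data-parallel axes computes
dp_loss = local_loss.mean()
dp_grad = local_grad.mean(axis=0)

full_loss, full_grad = jax.value_and_grad(mse)(w, x, y)
print(dp_loss, full_loss)
```

Summing the local losses (`psum`) instead would give `num_replicas` times the loss, which is why `loss_dp` uses `pmean`. In `loss_dp` the gradient all-reduce is not written by hand: the gradient is taken outside `shard_map`, and differentiating through it reduces the gradients of the replicated (`P()`) weights across devices.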
| { | |
| "cell_type": "code", | |
| "execution_count": 42, | |
| "id": "781aace1-1042-4543-a678-f8e555224c77", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "DP Loss: 4.0431547\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# fully replicate model weights\n", | |
| "linear_pspec: LinearPSpec = Linear(w=P(), b=P())\n", | |
| "model_pspec = Model(\n", | |
| " layer_sizes=model.layer_sizes, layers=[linear_pspec] * model.num_layers\n", | |
| ")\n", | |
| "# shard data along the batch dim over the combined dp x tp mesh axes\n", | |
| "batch_pspec: BatchPSpec = Batch(\n", | |
| " inputs=P((\"dp\", \"tp\"), None), targets=P((\"dp\", \"tp\"), None)\n", | |
| ")\n", | |
| "\n", | |
| "\n", | |
| "@jax.jit\n", | |
| "@jax.value_and_grad\n", | |
| "@jax.shard_map(mesh=global_mesh, in_specs=(model_pspec, batch_pspec), out_specs=P())\n", | |
| "def loss_dp(model_sharded: Model, batch_sharded: BatchData) -> jax.Array:\n", | |
| " # Inputs and Targets: f32[16 @ (dp=4, tp=2), 32] -> f32[2{dp,tp}, 32]\n", | |
| " # w, b: fully replicated: x @ w + b -> y f32[2{dp,tp}, 4]\n", | |
| " preds = model_sharded.predict(batch_sharded.inputs, linear_fwd_fn=linear)\n", | |
| " # local per-example MSE loss\n", | |
| " loss_per_example = jnp.sum((preds - batch_sharded.targets) ** 2, axis=-1)\n", | |
| " # local mean MSE loss\n", | |
| " loss_local = jnp.mean(loss_per_example)\n", | |
| " # all reduce mean across devices -> total mean mse loss\n", | |
| " return jax.lax.pmean(loss_local, (\"dp\", \"tp\"))\n", | |
| "\n", | |
| "\n", | |
| "loss_dp_, grads_dp_ = loss_dp(model, batch)\n", | |
| "print(\"DP Loss:\", loss_dp_)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 43, | |
| "id": "58107fd8-3790-4b88-892e-02aa8301e124", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "2.11 ms ± 94.8 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%timeit jax.block_until_ready(loss_dp(model, batch))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 44, | |
| "id": "fac73302-8cb3-4850-bb41-1ce3fbc17f7d", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def check_all_close(a: PyTree[Array], b: PyTree[Array], atol=1e-5) -> bool:\n", | |
| " return jax.tree.all(jax.tree.map(partial(jnp.allclose, atol=atol), a, b))\n", | |
| "\n", | |
| "\n", | |
| "# Verify correctness\n", | |
| "assert check_all_close((loss_dp_, grads_dp_), (loss_baseline_, grads_))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "90392021-fe5d-45fc-9106-82e31c3fb793", | |
| "metadata": {}, | |
| "source": [ | |
| "#### FSDP\n", | |
| "In addition to DP, shard the model weights over the data-parallel mesh axes (here `dp` and `tp`).\n", | |
| "* Forward: all-gather each layer's weights just before its forward computation (optionally prefetching the next layer's), and free the gathered copy once that layer's forward is done instead of keeping the full weights in device memory.\n", | |
| "* Backward: all-gather the weights again (rematerialization) before computing each layer's gradients, then reduce-scatter the weight gradients so they are summed across data-parallel devices and stay sharded like the weights.\n", | |
| "\n", | |
| "A few words about `jax.checkpoint`: \n", | |
| "* Standard backprop: when you run `jax.grad(loss)(x, y)`, JAX saves the intermediate values from the forward pass that the backward pass needs (e.g. `h = x @ w`). This is fast but memory-intensive, especially for deep architectures with many layers.\n", | |
| "* Activation checkpointing: with `jax.checkpoint` and its default policy, JAX saves only the inputs of the checkpointed function and recomputes the intermediates it needs during the backward pass, trading extra compute for lower memory. A `policy` callable can mark some intermediates as saveable anyway: the policy used below saves every primitive's output except those of `all_gather`, so the full weights are re-gathered in the backward pass instead of being stored." | |
| ] | |
| }, | |
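A minimal single-device sketch of a `jax.checkpoint` policy with the same shape as the one used below: a callable that receives each primitive and returns whether its output may be saved. As an illustrative stand-in for `all_gather`, this policy refuses to save the output of `sin`; the toy `block` function is hypothetical. Rematerialization changes what is stored, not what is computed, so the gradients match the un-checkpointed version:

```python
import jax
import jax.numpy as jnp


def block(x, w):
    # sin(x @ w) plays the role of an intermediate we would rather recompute than store
    h = jnp.sin(x @ w)
    return jnp.sum(h ** 2)


def save_all_but_sin(prim, *_, **__):
    # True = this primitive's output may be kept for the backward pass
    return prim.name != "sin"


block_remat = jax.checkpoint(block, policy=save_all_but_sin)

x = jax.random.normal(jax.random.key(0), (8, 16))
w = jax.random.normal(jax.random.key(1), (16, 16))

g_plain = jax.grad(block, argnums=1)(x, w)
g_remat = jax.grad(block_remat, argnums=1)(x, w)
print(jnp.max(jnp.abs(g_plain - g_remat)))
```

To inspect which residuals a policy actually keeps, the JAX gradient-checkpointing guide uses `jax.ad_checkpoint.print_saved_residuals`.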
| { | |
| "cell_type": "code", | |
| "execution_count": 45, | |
| "id": "7b6de1dc-b642-46c8-a524-ef07570a9578", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "FSDP Loss: 4.0431547\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Shard model weights over dp x tp axis\n", | |
| "linear_pspec: LinearPSpec = Linear(\n", | |
| " w=P((\"dp\", \"tp\"), None),\n", | |
| " b=P(\"tp\"), # can also be P() to replicate bias since it's cheap\n", | |
| ")\n", | |
| "model_pspec = Model(\n", | |
| " layer_sizes=model.layer_sizes, layers=[linear_pspec] * model.num_layers\n", | |
| ")\n", | |
| "# shard data along batch dim over dp x tp mesh axis. Same as DP\n", | |
| "batch_pspec: BatchPSpec = Batch(\n", | |
| " inputs=P((\"dp\", \"tp\"), None), targets=P((\"dp\", \"tp\"), None)\n", | |
| ")\n", | |
| "\n", | |
| "\n", | |
| "# remat policy: keep every op's output except all_gather's; the all-gathered weights are re-gathered during backprop\n", | |
| "@partial(jax.checkpoint, policy=lambda p, *_, **__: p.name != \"all_gather\")\n", | |
| "def linear_fsdp(x: jax.Array, layer: LinearParams) -> jax.Array:\n", | |
| " \"\"\"All-gather the sharded weight and bias, then compute x @ w + b locally\"\"\"\n", | |
| " w = jax.lax.all_gather(layer.w, (\"dp\", \"tp\"), axis=0, tiled=True)\n", | |
| " b = jax.lax.all_gather(layer.b, \"tp\", axis=0, tiled=True)\n", | |
| " return x @ w + b\n", | |
| "\n", | |
| "\n", | |
| "@jax.jit\n", | |
| "@jax.value_and_grad\n", | |
| "@jax.shard_map(mesh=global_mesh, in_specs=(model_pspec, batch_pspec), out_specs=P())\n", | |
| "def loss_fsdp(model_sharded: Model, batch_sharded: BatchData) -> jax.Array:\n", | |
| " # Inputs and Targets: f32[16 @ (dp=4, tp=2), 32] -> f32[2{dp,tp}, 32]\n", | |
| " # w, b: all-gathered before local computation: x @ w + b -> y f32[2{dp,tp}, 4]\n", | |
| " preds = model_sharded.predict(batch_sharded.inputs, linear_fwd_fn=linear_fsdp)\n", | |
| " # local per-example MSE loss\n", | |
| " loss_per_example = jnp.sum((preds - batch_sharded.targets) ** 2, axis=-1)\n", | |
| " # local mean MSE loss\n", | |
| " loss_local = jnp.mean(loss_per_example)\n", | |
| " # all-reduce mean across dp x tp devices -> global mean MSE loss\n", | |
| " return jax.lax.pmean(loss_local, (\"dp\", \"tp\"))\n", | |
| "\n", | |
| "\n", | |
| "loss_fsdp_, grads_fsdp_ = loss_fsdp(model, batch)\n", | |
| "print(\"FSDP Loss:\", loss_fsdp_)\n", | |
| "# Verify correctness\n", | |
| "assert check_all_close((loss_fsdp_, grads_fsdp_), (loss_baseline_, grads_))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 46, | |
| "id": "24b3a8bc-8e5f-4dd6-9826-84af1e65bfb8", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "3.21 ms ± 344 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%timeit jax.block_until_ready(loss_fsdp(model, batch))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 47, | |
| "id": "63199c5d-4283-47be-a02b-8695d39dfa0a", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "FSDP Loss: 4.0431547\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# This is also doing FSDP but with a different sharding spec for model weights\n", | |
| "# Shard model weights along dim=0 over dp axis and along dim=1 over tp axis\n", | |
| "linear_pspec: LinearPSpec = Linear(\n", | |
| " w=P(\"dp\", \"tp\"),\n", | |
| " b=P(\"tp\"), # can also be P() to replicate bias since it's cheap\n", | |
| ")\n", | |
| "model_pspec = Model(\n", | |
| " layer_sizes=model.layer_sizes, layers=[linear_pspec] * model.num_layers\n", | |
| ")\n", | |
| "# shard data along batch dim over dp x tp mesh axis. Same as DP\n", | |
| "batch_pspec: BatchPSpec = Batch(\n", | |
| " inputs=P((\"dp\", \"tp\"), None), targets=P((\"dp\", \"tp\"), None)\n", | |
| ")\n", | |
| "\n", | |
| "\n", | |
| "# remat policy: keep every op's output except all_gather's; the all-gathered weights are re-gathered during backprop\n", | |
| "@partial(jax.checkpoint, policy=lambda p, *_, **__: p.name != \"all_gather\")\n", | |
| "def linear_fsdp(x: jax.Array, layer: LinearParams) -> jax.Array:\n", | |
| " \"\"\"All-gather w over tp (axis=1) then over dp (axis=0), gather b over tp, then compute x @ w + b\"\"\"\n", | |
| " w_partial = jax.lax.all_gather(layer.w, \"tp\", axis=1, tiled=True)\n", | |
| " w = jax.lax.all_gather(w_partial, \"dp\", axis=0, tiled=True)\n", | |
| " b = jax.lax.all_gather(layer.b, \"tp\", axis=0, tiled=True)\n", | |
| " return x @ w + b\n", | |
| "\n", | |
| "\n", | |
| "@jax.jit\n", | |
| "@jax.value_and_grad\n", | |
| "@jax.shard_map(mesh=global_mesh, in_specs=(model_pspec, batch_pspec), out_specs=P())\n", | |
| "def loss_fsdp(model_sharded: Model, batch_sharded: BatchData) -> jax.Array:\n", | |
| " # Inputs and Targets: f32[16 @ (dp=4, tp=2), 32] -> f32[2{dp,tp}, 32]\n", | |
| " # w, b: all-gathered before local computation: x @ w + b -> y f32[2{dp,tp}, 4]\n", | |
| " preds = model_sharded.predict(batch_sharded.inputs, linear_fwd_fn=linear_fsdp)\n", | |
| " # local per-example MSE loss\n", | |
| " loss_per_example = jnp.sum((preds - batch_sharded.targets) ** 2, axis=-1)\n", | |
| " # local mean MSE loss\n", | |
| " loss_local = jnp.mean(loss_per_example)\n", | |
| " # all-reduce mean across dp x tp devices -> global mean MSE loss\n", | |
| " return jax.lax.pmean(loss_local, (\"dp\", \"tp\"))\n", | |
| "\n", | |
| "\n", | |
| "loss_fsdp_, grads_fsdp_ = loss_fsdp(model, batch)\n", | |
| "print(\"FSDP Loss:\", loss_fsdp_)\n", | |
| "# Verify correctness\n", | |
| "assert check_all_close((loss_fsdp_, grads_fsdp_), (loss_baseline_, grads_))" | |
| ] | |
| }, | |
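The two-step gather in this variant can be checked on one device. A sketch assuming the notebook's 4 x 2 `("dp", "tp")` mesh, with devices modeled as dictionary keys (names are illustrative): under `P("dp", "tp")` device `(d, t)` holds row block `d` and column block `t`, so an `all_gather` over `tp` on axis 1 followed by one over `dp` on axis 0 rebuilds the full `w`:

```python
import jax.numpy as jnp

dp, tp = 4, 2
w = jnp.arange(32 * 8, dtype=jnp.float32).reshape(32, 8)

# P("dp", "tp"): rows split over dp, columns split over tp
row_blocks = jnp.split(w, dp, axis=0)
local = {
    (d, t): jnp.split(row_blocks[d], tp, axis=1)[t]
    for d in range(dp)
    for t in range(tp)
}

# all_gather(..., "tp", axis=1, tiled=True): concatenate the tp shards of one dp row block
w_partial = {d: jnp.concatenate([local[(d, t)] for t in range(tp)], axis=1) for d in range(dp)}
# all_gather(..., "dp", axis=0, tiled=True): concatenate the dp row blocks in dp order
w_full = jnp.concatenate([w_partial[d] for d in range(dp)], axis=0)

print(local[(0, 0)].shape, w_partial[0].shape, w_full.shape)  # (8, 4) (8, 8) (32, 8)
```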
| { | |
| "cell_type": "code", | |
| "execution_count": 48, | |
| "id": "fa598b1d-164b-4a9b-bc23-664ed23b4e3f", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "2.63 ms ± 159 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%timeit jax.block_until_ready(loss_fsdp(model, batch))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "0095bc89-ef43-4ef4-bbd9-506779c3c9f8", | |
| "metadata": {}, | |
| "source": [ | |
| "#### TP\n", | |
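| "Shard the inputs (activations) along the feature (column) dim over the `tp` mesh axis, and shard the weights along the row dim over the `tp` mesh axis.\n", | |
| "\n", | |
| "Each `tp` device multiplies its column block of the inputs by the matching row block of the weights, which yields a partial sum of the full output. Summing the partials over `tp` (an all-reduce) gives the full product. When another sharded matmul follows, fuse the sum with a scatter along the column dim over `tp` (a reduce-scatter, `psum_scatter`), so the output is already column-sharded the way the next layer expects.\n" | |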
| ] | |
| }, | |
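The column x row decomposition can be checked on one device. A sketch assuming `tp = 2` and toy shapes (not the notebook's model): each simulated device holds a column block of `x` and the matching row block of `w`; the partial products sum to `x @ w`, and keeping only device `i`'s column block of that sum is what `psum_scatter(..., scatter_dimension=1, tiled=True)` leaves on device `i`:

```python
import jax
import jax.numpy as jnp

tp = 2
kx, kw = jax.random.split(jax.random.key(0))
x = jax.random.normal(kx, (16, 32))  # activations [batch, features]
w = jax.random.normal(kw, (32, 8))   # weights [in_features, out_features]

x_blocks = jnp.split(x, tp, axis=1)  # device i: x[:, 16*i : 16*(i+1)]
w_blocks = jnp.split(w, tp, axis=0)  # device i: w[16*i : 16*(i+1), :]

# Each device's local matmul is a full-shape [16, 8] partial sum
partials = [xb @ wb for xb, wb in zip(x_blocks, w_blocks)]

# psum: add the partials -> the full product on every device
y_full = sum(partials)
# psum_scatter(scatter_dimension=1, tiled=True): device i keeps column block i of the sum
y_scattered = jnp.split(y_full, tp, axis=1)
print(y_scattered[0].shape)  # (16, 4)
```

Because the scattered output is column-sharded, it already has the `P(None, "tp")` layout that the next layer's input expects, which is what lets `linear_tp` be chained layer after layer.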
| { | |
| "cell_type": "code", | |
| "execution_count": 49, | |
| "id": "61af0444-7eed-408f-86fd-74877db9d1b0", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "TP Loss: 4.0431542\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Shard model weights row-wise over tp axis\n", | |
| "linear_pspec: LinearPSpec = Linear(w=P(\"tp\", None), b=P(\"tp\"))\n", | |
| "# for b, can also use P() to replicate across tp since it is cheap\n", | |
| "model_pspec = Model(\n", | |
| " layer_sizes=model.layer_sizes, layers=[linear_pspec] * model.num_layers\n", | |
| ")\n", | |
| "# shard data along feature (column) dim over tp mesh axis.\n", | |
| "batch_pspec: BatchPSpec = Batch(inputs=P(None, \"tp\"), targets=P(None, \"tp\"))\n", | |
| "\n", | |
| "\n", | |
| "def linear_tp(x: jax.Array, layer: LinearParams) -> jax.Array:\n", | |
| " \"\"\"Local column-block x row-block matmul, then reduce-scatter the partial sums over tp\"\"\"\n", | |
| " matmul_local = x @ layer.w\n", | |
| " matmul_reduce_scattered = jax.lax.psum_scatter(\n", | |
| " matmul_local, \"tp\", scatter_dimension=1, tiled=True\n", | |
| " )\n", | |
| " return matmul_reduce_scattered + layer.b\n", | |
| "\n", | |
| "\n", | |
| "@jax.jit\n", | |
| "@partial(jax.value_and_grad, argnums=0)\n", | |
| "@jax.shard_map(mesh=global_mesh, in_specs=(model_pspec, batch_pspec), out_specs=P())\n", | |
| "def loss_tp(model_sharded: Model, batch_sharded: BatchData) -> jax.Array:\n", | |
| " # Inputs: f32[16, 32@tp] -> f32[16, 16{tp=2}]; Targets: f32[16, 4@tp] -> f32[16, 2{tp=2}]\n", | |
| " # w, b: sharded as rows over tp: x @ w -> col x row -> partial matmul -> reduce-scatter\n", | |
| " # along the column dim over tp, + b (replicated or sharded over tp) -> y f32[16, 4@tp]\n", | |
| " preds = model_sharded.predict(batch_sharded.inputs, linear_fwd_fn=linear_tp)\n", | |
| " loss_local = jnp.sum((preds - batch_sharded.targets) ** 2, axis=-1)\n", | |
| " loss_per_example = jax.lax.psum(loss_local, \"tp\") # f32[16, ]\n", | |
| " return jnp.mean(loss_per_example)\n", | |
| "\n", | |
| "\n", | |
| "loss_tp_, grads_tp_ = loss_tp(model, batch)\n", | |
| "print(\"TP Loss:\", loss_tp_)\n", | |
| "# Verify correctness\n", | |
| "assert check_all_close((loss_tp_, grads_tp_), (loss_baseline_, grads_))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 50, | |
| "id": "5b569900-0f23-4bbe-baed-6b5969d7ba7f", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "811 μs ± 31.8 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%timeit jax.block_until_ready(loss_tp(model, batch))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "b6ba5afb-7572-40a0-be8a-5ead0cd0b681", | |
| "metadata": {}, | |
| "source": [ | |
| "#### FSDP + TP" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 51, | |
| "id": "3bfb53b2-0979-4347-9b88-17a041f79590", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "FSDP_TP Loss: 4.0431542\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Shard weight rows over tp (major) x dp (minor); shard the bias over tp only\n", | |
| "linear_pspec: LinearPSpec = Linear(w=P((\"tp\", \"dp\")), b=P(\"tp\"))\n", | |
| "model_pspec = Model(\n", | |
| " layer_sizes=model.layer_sizes, layers=[linear_pspec] * model.num_layers\n", | |
| ")\n", | |
| "# shard data: batch dim over dp, feature dim over tp\n", | |
| "batch_pspec: BatchPSpec = Batch(inputs=P(\"dp\", \"tp\"), targets=P(\"dp\", \"tp\"))\n", | |
| "\n", | |
| "\n", | |
| "@partial(jax.checkpoint, policy=lambda p, *_, **__: p.name != \"all_gather\")\n", | |
| "def linear_fsdp_tp(x: jax.Array, layer: LinearParams) -> jax.Array:\n", | |
| " w = jax.lax.all_gather(layer.w, \"dp\", axis=0, tiled=True)\n", | |
| " # Note: if w were partitioned as P(\"tp\", \"dp\") instead of P((\"tp\", \"dp\")), all-gather along axis=1 instead. Both work.\n", | |
| " # If b were also sharded over dp, e.g. P((\"tp\", \"dp\")), it would need its own all-gather over dp along axis=0.\n", | |
| " b = layer.b\n", | |
| " matmul_local = x @ w\n", | |
| " matmul_local_reduce_scattered = jax.lax.psum_scatter(\n", | |
| " matmul_local, \"tp\", scatter_dimension=1, tiled=True\n", | |
| " )\n", | |
| " return matmul_local_reduce_scattered + b\n", | |
| "\n", | |
| "\n", | |
| "@jax.jit\n", | |
| "@jax.value_and_grad\n", | |
| "@jax.shard_map(mesh=global_mesh, in_specs=(model_pspec, batch_pspec), out_specs=P())\n", | |
| "def loss_fsdp_tp(model_sharded: Model, batch_sharded: BatchData) -> jax.Array:\n", | |
| " # Inputs: f32[16@dp, 32@tp] -> f32[4{dp=4}, 16{tp=2}]; Targets: f32[16@dp, 4@tp] -> f32[4{dp=4}, 2{tp=2}]\n", | |
| " # w: rows sharded over tp x dp; all-gather over dp (concatenating along axis=0) -> w[rows@tp, :]\n", | |
| " # x @ w -> col x row partial matmul -> reduce-scatter along dim=1 over tp, + b\n", | |
| " # (replicated or sharded over tp) -> y f32[4{dp=4}, 4@tp]\n", | |
| " preds = model_sharded.predict(batch_sharded.inputs, linear_fwd_fn=linear_fsdp_tp)\n", | |
| " # same as tp\n", | |
| " loss_local = jnp.sum((preds - batch_sharded.targets) ** 2, axis=-1)\n", | |
| " loss_per_example = jax.lax.psum(loss_local, \"tp\") # f32[16@dp, ] = f32[4{dp=4}]\n", | |
| " return jax.lax.pmean(jnp.mean(loss_per_example), axis_name=\"dp\")\n", | |
| "\n", | |
| "\n", | |
| "loss_fsdp_tp_, grads_fsdp_tp_ = loss_fsdp_tp(model, batch)\n", | |
| "print(\"FSDP_TP Loss:\", loss_fsdp_tp_)\n", | |
| "# Verify correctness\n", | |
| "assert check_all_close((loss_fsdp_tp_, grads_fsdp_tp_), (loss_baseline_, grads_))" | |
| ] | |
| }, | |
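A single-device sketch of why one `all_gather` over `dp` is enough in `linear_fsdp_tp`, assuming the notebook's mesh sizes (`tp = 2`, `dp = 4`) and a toy 32 x 4 weight. In a multi-axis entry such as `P(("tp", "dp"))` the first mesh axis is the major one, so device `(t, d)` owns row block `t * dp + d`; gathering over `dp` (tiled, along axis 0) therefore yields exactly the rows that `P("tp", None)` would give device `t`:

```python
import jax.numpy as jnp

tp, dp = 2, 4
w = jnp.arange(32 * 4, dtype=jnp.float32).reshape(32, 4)

# P(("tp", "dp")) on dim 0: 8 row blocks, tp-major, so device (t, d) owns block t * dp + d
blocks = jnp.split(w, tp * dp, axis=0)
local = {(t, d): blocks[t * dp + d] for t in range(tp) for d in range(dp)}

# all_gather(..., "dp", axis=0, tiled=True) on device (t, d): concatenate blocks in dp order
gathered = {t: jnp.concatenate([local[(t, d)] for d in range(dp)], axis=0) for t in range(tp)}

# The pure-TP row shard P("tp", None) for comparison
tp_shards = jnp.split(w, tp, axis=0)
print(local[(0, 0)].shape, gathered[0].shape)  # (4, 4) (16, 4)
```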
| { | |
| "cell_type": "code", | |
| "execution_count": 52, | |
| "id": "640d3c98-b757-4fde-8369-cd1ccd86e527", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "2.01 ms ± 95.8 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%timeit jax.block_until_ready(loss_fsdp_tp(model, batch))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "57a08331-15db-483c-b41c-fdd025759fe3", | |
| "metadata": {}, | |
| "source": [ | |
| "In practice, you can have a device mesh with more axes, e.g. `(\"dp\", \"fsdp\", \"tp\", \"ep\")`.\n", | |
| "\n", | |
| "For example, you could shard the data with `P((\"dp\", \"fsdp\"), \"tp\")` and the Linear layer weights with `P(\"tp\", \"fsdp\")`. In the forward pass, all-gather the weights over the `fsdp` mesh axis along `dim=1` (the dim that `fsdp` shards); the row dim stays sharded over `tp` for the TP matmul." | |
| ] | |
| }, | |
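A pure-Python sketch of the per-device shapes under that layout, assuming hypothetical mesh sizes `dp=2, fsdp=2, tp=2, ep=1` and toy global shapes; `local_shape` is an illustrative helper, not a JAX API. It shows that after the `fsdp` all-gather on `dim=1`, each device's weight is a pure TP row shard whose row count matches the feature dim of its data shard:

```python
from math import prod

mesh_shape = {"dp": 2, "fsdp": 2, "tp": 2, "ep": 1}  # hypothetical axis sizes


def local_shape(global_shape, spec):
    """Per-device shape for a PartitionSpec-like tuple (illustrative helper).

    Each entry is None (replicated), a mesh axis name, or a tuple of names
    (the dim is split over the product of those axis sizes).
    """
    spec = tuple(spec) + (None,) * (len(global_shape) - len(spec))
    out = []
    for size, entry in zip(global_shape, spec):
        names = () if entry is None else (entry if isinstance(entry, tuple) else (entry,))
        n = prod(mesh_shape[a] for a in names)
        assert size % n == 0, f"dim of size {size} is not divisible by {n}"
        out.append(size // n)
    return tuple(out)


# Data [batch, features] with P(("dp", "fsdp"), "tp")
x_local = local_shape((64, 512), (("dp", "fsdp"), "tp"))
# Linear weight [in_features, out_features] with P("tp", "fsdp")
w_local = local_shape((512, 1024), ("tp", "fsdp"))
# After all_gather over "fsdp" along dim=1, the weight is laid out as P("tp", None)
w_gathered = local_shape((512, 1024), ("tp", None))
print(x_local, w_local, w_gathered)  # (16, 256) (256, 512) (256, 1024)
```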
| { | |
| "cell_type": "markdown", | |
| "id": "2de60ea3-f57d-4efe-85a3-3e3d8e3c85ee", | |
| "metadata": {}, | |
| "source": [ | |
| "#### MOE Parallel with Token Dropping\n", | |
| "Due to JAX [issue #34168](https://github.com/jax-ml/jax/issues/34168), this code does not run on GPU: `ragged_dot` is not well supported there." | |
| ] | |
| }, | |
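The drop rule implemented by `calculate_drop_token_mask` in the next cell can be spelled out as a plain-Python loop, which is handy as a test reference. A sketch with a hypothetical helper name and toy routing: within each expert, tokens are ranked by router probability (highest first), and tokens ranked at or beyond `expert_capacity` are dropped:

```python
def drop_mask_reference(expert_ids, probs, expert_capacity):
    """Loop version of the token-dropping rule (True = dropped). Illustrative helper."""
    drop = [False] * len(expert_ids)
    for expert in set(expert_ids):
        members = [i for i, e in enumerate(expert_ids) if e == expert]
        # Highest router probability first (Python's sort is stable, so ties keep token order)
        members.sort(key=lambda i: -probs[i])
        for i in members[expert_capacity:]:
            drop[i] = True
    return drop


# Expert 0 receives tokens 0, 2, 3 but has room for 2: the lowest-probability one (token 2) is dropped
expert_ids = [0, 1, 0, 0, 1, 2]
probs = [0.9, 0.8, 0.5, 0.7, 0.95, 0.3]
mask = drop_mask_reference(expert_ids, probs, expert_capacity=2)
print(mask)  # [False, False, True, False, False, False]
```

Calling `calculate_drop_token_mask(jnp.array(expert_ids), jnp.array(probs), num_experts=3, expert_capacity=2)` should produce the same mask.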
| { | |
| "cell_type": "code", | |
| "execution_count": 53, | |
| "id": "fb8d3b5c-3155-48ce-878c-6e3093aa53cd", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "\"\"\"Mixture of Experts (MoE) Layer with token dropping\n", | |
| "\n", | |
| "Using Ragged All-to-All Communication and Ragged Dot in JAX.\n", | |
| "\"\"\"\n", | |
| "\n", | |
| "from __future__ import annotations\n", | |
| "\n", | |
| "__author__ = \"Liutong Zhou\"\n", | |
| "\n", | |
| "from dataclasses import dataclass, field\n", | |
| "from typing import Self\n", | |
| "\n", | |
| "import jax\n", | |
| "import jax.numpy as jnp\n", | |
| "from jax.sharding import Mesh, PartitionSpec as P\n", | |
| "from jaxtyping import Array, Bool, Float, Int\n", | |
| "\n", | |
| "\n", | |
| "def calculate_offsets(sizes: Int[Array, \"n\"]) -> Int[Array, \"n\"]:\n", | |
| " \"\"\"Calculate chunk start offsets given chunk sizes\n", | |
| "\n", | |
| " Parameters\n", | |
| " ----------\n", | |
| " sizes : Int[Array, \"n\"]\n", | |
| " An array of chunk sizes\n", | |
| "\n", | |
| " Returns\n", | |
| " -------\n", | |
| " offsets : Int[Array, \"n\"]\n", | |
| " An array of offsets starting at 0 such that offsets[i] is the starting index\n", | |
| " of chunk i in a flattened array.\n", | |
| "\n", | |
| " Examples\n", | |
| " --------\n", | |
| " >>> sizes = jnp.array([2, 3, 1])\n", | |
| " >>> calculate_offsets(sizes)\n", | |
| " Array([0, 2, 5], dtype=int32)\n", | |
| " \"\"\"\n", | |
| " sizes_i32 = jnp.asarray(sizes, dtype=jnp.int32)\n", | |
| " # Pad with 0 at the start, remove the last element, then cumsum\n", | |
| " return jnp.cumsum(jnp.pad(sizes_i32[:-1], (1, 0)))\n", | |
| "\n", | |
| "\n", | |
| "def calculate_drop_token_mask(\n", | |
| " token_expert_ids: Int[Array, \"n_tokens\"],\n", | |
| " token_expert_probs: Float[Array, \"n_tokens\"],\n", | |
| " num_experts: int | Int[Array, \"\"],\n", | |
| " expert_capacity: int | Int[Array, \"\"],\n", | |
| ") -> Bool[Array, \"n_tokens\"]:\n", | |
| " \"\"\"Calculate a boolean mask to indicate which tokens should be dropped based on expert capacity.\n", | |
| "\n", | |
| " Parameters\n", | |
| " ----------\n", | |
| " token_expert_ids : Int[Array, \"n_tokens\"]\n", | |
| " 1-D array of expert IDs assigned to each token.\n", | |
| " token_expert_probs : Float[Array, \"n_tokens\"]\n", | |
| " 1-D array of probabilities assigned to each token for its expert.\n", | |
| " num_experts : int\n", | |
| " Total number of experts.\n", | |
| " expert_capacity : int\n", | |
| " Maximum number of tokens each expert can handle.\n", | |
| "\n", | |
| " Returns\n", | |
| " -------\n", | |
| " drop_mask : Bool[Array, \"n_tokens\"]\n", | |
| " A boolean mask indicating which tokens should be dropped (True) or kept (False).\n", | |
| " \"\"\"\n", | |
| " # Calculate each token's within-expert rank\n", | |
| " # Sort token experts by expert ID (ascending) and then by probability (descending)\n", | |
| " # so that after sorting, tokens for expert 0 are at the top, followed by tokens for expert 1, etc.\n", | |
| " index_to_sorted = jnp.lexsort((-token_expert_probs, token_expert_ids))\n", | |
| " token_expert_ids_sorted_by_rank = token_expert_ids[index_to_sorted]\n", | |
| "\n", | |
| " # Vectorized within-expert rank calculation\n", | |
| " # Find where each expert segment starts in token_expert_ids_sorted_by_rank\n", | |
| " expert_load_size = jnp.bincount(token_expert_ids, length=num_experts)\n", | |
| " expert_start_offsets = calculate_offsets(expert_load_size)\n", | |
| " # token absolute index - its expert start offset gives within-expert rank\n", | |
| " token_within_expert_rank = (\n", | |
| " jnp.arange(token_expert_ids_sorted_by_rank.shape[0])\n", | |
| " - expert_start_offsets[token_expert_ids_sorted_by_rank]\n", | |
| " )\n", | |
| "\n", | |
| " # Token dropping: cap the rank at expert_capacity\n", | |
| " is_dropped_sorted = token_within_expert_rank >= expert_capacity\n", | |
| "\n", | |
| " # revert the sorting to get the original token order\n", | |
| " index_to_inverse_sorted = jnp.argsort(index_to_sorted)\n", | |
| " drop_mask = is_dropped_sorted[index_to_inverse_sorted]\n", | |
| " return drop_mask\n", | |
| "\n", | |
| "\n", | |
| "@jax.tree_util.register_dataclass\n", | |
| "@dataclass(frozen=True, slots=True)\n", | |
| "class MOERaggedDispatcher:\n", | |
| "\"\"\"Orchestrate MoE token dispatch and return across devices using ragged all-to-all communication.\n", | |
| "\n", | |
| " The `from_target_device_ids` constructor pre-calculates the array layout required to move variable-sized\n", | |
| " slices of data between devices. It computes where data should be read from\n", | |
| " (input offsets) and where it should be written to (output offsets) for both\n", | |
| " the forward dispatch and the backward return trip.\n", | |
| "\n", | |
| " Attributes\n", | |
| " ----------\n", | |
| " send_sizes : Int[Array, \"devices\"]\n", | |
| " Number of items this device sends to each peer device.\n", | |
| " input_offsets : Int[Array, \"devices\"]\n", | |
| " Local array offsets to read data from during send.\n", | |
| " recv_sizes : Int[Array, \"devices\"]\n", | |
| " Number of items this device receives from each peer device.\n", | |
| " recv_offsets : Int[Array, \"devices\"]\n", | |
| " Local array offsets where received data will be stored.\n", | |
| " fwd_remote_output_offsets : Int[Array, \"devices\"]\n", | |
| " The offsets on the *receiver* devices where our sent data should be written.\n", | |
| " bwd_remote_output_offsets : Int[Array, \"devices\"]\n", | |
| " The offsets on the *original sender* devices where the returned results\n", | |
| " should be written.\n", | |
| " \"\"\"\n", | |
| "\n", | |
| " # Forward Pass Layout\n", | |
| " send_sizes: Int[Array, \"devices\"] # how much to read locally\n", | |
| " input_offsets: Int[Array, \"devices\"] # where to read from locally\n", | |
| " recv_sizes: Int[Array, \"devices\"] # how much to write locally\n", | |
| " recv_offsets: Int[Array, \"devices\"] # where to write locally\n", | |
| "\n", | |
| " # Remote Writing Instructions\n", | |
| " fwd_remote_output_offsets: Int[Array, \"devices\"] # where to write remotely\n", | |
| " # where to write remotely on return trip\n", | |
| " bwd_remote_output_offsets: Int[Array, \"devices\"]\n", | |
| "\n", | |
| " axis_name: str | tuple[str, ...] = field(\n", | |
| " default=\"ep\", metadata={\"static\": True}\n", | |
| " ) # mark as static for JAX jit\n", | |
| "\n", | |
| " @property\n", | |
| " def num_devices(self) -> int:\n", | |
| " \"\"\"Number of devices involved in all-to-all communication.\"\"\"\n", | |
| " return jax.lax.axis_size(self.axis_name)\n", | |
| "\n", | |
| " @classmethod\n", | |
| " def from_target_device_ids(\n", | |
| " cls,\n", | |
| " target_device_ids: Int[Array, \"n\"],\n", | |
| " *,\n", | |
| " axis_name: str | tuple[str, ...],\n", | |
| " mask: Bool[Array, \"n\"] | None = None,\n", | |
| " ) -> Self:\n", | |
| " \"\"\"Create a globally consistent communication plan.\n", | |
| "\n", | |
| " Parameters\n", | |
| " ----------\n", | |
| " target_device_ids : Int[Array, \"n\"]\n", | |
| " 1-D array indicating the target device ID for each item to be sent.\n", | |
| " axis_name : str | tuple[str, ...]\n", | |
| " Device mesh axis name(s) along which to perform all-to-all communication.\n", | |
| " mask : Bool[Array, \"n\"], optional\n", | |
| " Optional boolean mask indicating which items to send (True) or drop (False).\n", | |
| "\n", | |
| " Returns\n", | |
| " -------\n", | |
| " MOERaggedDispatcher\n", | |
| " An instance of MOERaggedDispatcher with precomputed communication layout.\n", | |
| " \"\"\"\n", | |
| " num_devices = jax.lax.axis_size(axis_name)\n", | |
| " device_send_load = jnp.bincount(target_device_ids, length=num_devices)\n", | |
| " input_offsets = calculate_offsets(device_send_load)\n", | |
| " # Only send this much, drop the masked-out items\n", | |
| " device_send_load_actual = jnp.bincount(\n", | |
| " target_device_ids,\n", | |
| " weights=mask.astype(target_device_ids.dtype) if mask is not None else None,\n", | |
| " length=num_devices,\n", | |
| " )\n", | |
| "\n", | |
| " # 1. Exchange sizes: Tell peer devices how much this device is sending;\n", | |
| " # receive how much they are sending to this device\n", | |
| " recv_sizes = jax.lax.all_to_all(\n", | |
| " device_send_load_actual,\n", | |
| " axis_name=axis_name,\n", | |
| " split_axis=0,\n", | |
| " concat_axis=0,\n", | |
| " tiled=True,\n", | |
| " )\n", | |
| "\n", | |
| " recv_offsets = calculate_offsets(recv_sizes)\n", | |
| "\n", | |
| " # 2. Exchange Output Offsets:\n", | |
| " # a) Forward: Tell original senders where to write in receivers' buffer.\n", | |
| " fwd_remote_output_offsets = jax.lax.all_to_all(\n", | |
| " recv_offsets,\n", | |
| " axis_name=axis_name,\n", | |
| " split_axis=0,\n", | |
| " concat_axis=0,\n", | |
| " tiled=True,\n", | |
| " )\n", | |
| "\n", | |
| " # b) Backward: Tell senders (expert devices) where to write back in receivers' (original senders') buffer.\n", | |
| " # When the expert forward is done, they need to return data to the original device. Tell\n", | |
| " # expert devices to write it exactly where tokens were originally read from.\n", | |
| " bwd_remote_output_offsets = jax.lax.all_to_all(\n", | |
| " input_offsets,\n", | |
| " axis_name=axis_name,\n", | |
| " split_axis=0,\n", | |
| " concat_axis=0,\n", | |
| " tiled=True,\n", | |
| " )\n", | |
| "\n", | |
| " return cls(\n", | |
| " send_sizes=device_send_load_actual,\n", | |
| " input_offsets=input_offsets,\n", | |
| " recv_sizes=recv_sizes,\n", | |
| " recv_offsets=recv_offsets,\n", | |
| " fwd_remote_output_offsets=fwd_remote_output_offsets,\n", | |
| " bwd_remote_output_offsets=bwd_remote_output_offsets,\n", | |
| " axis_name=axis_name,\n", | |
| " )\n", | |
| "\n", | |
| " def dispatch_forward[T: Array](\n", | |
| " self, data: T, capacity: int | Int[Array, \"\"]\n", | |
| " ) -> tuple[T, Bool[Array, \"n\"]]:\n", | |
| " \"\"\"Send data to expert devices.\n", | |
| "\n", | |
| " Parameters\n", | |
| " ----------\n", | |
| " data : Array\n", | |
| " The input data (e.g. tokens) sorted by target device. If a mask was used, data must be\n", | |
| " sorted so that dropped items come last within each device group.\n", | |
| " capacity : int\n", | |
| " The total size of the receiver's buffer (must be sufficient for worst case).\n", | |
| "\n", | |
| " Returns\n", | |
| " -------\n", | |
| " received : Array\n", | |
| " The data received from peer devices, packed contiguously in output_buffer.\n", | |
| " Shape (capacity, ...).\n", | |
| " is_valid : Bool[Array, \"n\"]\n", | |
| " 1-D boolean mask indicating which received items are valid in the output.\n", | |
| " Slots beyond the total received size are padding and should be ignored.\n", | |
| " \"\"\"\n", | |
| " # pre-allocate an output buffer of static size for receiving data\n", | |
| " output_buffer = jnp.zeros((capacity,) + data.shape[1:], dtype=data.dtype)\n", | |
| " # each device sends data and returns the received data\n", | |
| " received = jax.lax.ragged_all_to_all(\n", | |
| " data,\n", | |
| " output_buffer, # for storing what this device will receive after forward dispatch\n", | |
| " self.input_offsets, # Read from here (local)\n", | |
| " self.send_sizes, # Read this amount (local)\n", | |
| " self.fwd_remote_output_offsets, # Write to here (remote)\n", | |
| " self.recv_sizes, # Write this amount (remote)\n", | |
| " axis_name=self.axis_name,\n", | |
| " )\n", | |
| " is_valid = jnp.arange(capacity) < jnp.sum(self.recv_sizes)\n", | |
| " return received, is_valid\n", | |
| "\n", | |
| " def dispatch_backward[T: Array](\n", | |
| " self, data: T, capacity: int | Int[Array, \"\"]\n", | |
| " ) -> tuple[T, Bool[Array, \"n\"]]:\n", | |
| " \"\"\"Return processed data to the original sender devices.\n", | |
| "\n", | |
| " Parameters\n", | |
| " ----------\n", | |
| " data : Array\n", | |
| " The processed data (must be in the same order as received).\n", | |
| " capacity : int\n", | |
| " The size of the buffer on the original sender (to restore shape).\n", | |
| "\n", | |
| " Returns\n", | |
| " -------\n", | |
| " returned_to_original_sender : Array\n", | |
| " The results, placed back into their original slots on the sender.\n", | |
| " is_valid : Bool[Array, \"n\"]\n", | |
| " 1-D boolean mask indicating which returned items are valid in the output.\n", | |
| " Returned items land in their original (pre-drop) slots; dropped slots hold zeros and are invalid.\n", | |
| " \"\"\"\n", | |
| " output_buffer = jnp.zeros((capacity,) + data.shape[1:], dtype=data.dtype)\n", | |
| "\n", | |
| " # Note: Roles of input/output offsets are effectively swapped for the return trip\n", | |
| " returned_to_original_sender = jax.lax.ragged_all_to_all(\n", | |
| " data,\n", | |
| " output_buffer,\n", | |
| " self.recv_offsets, # Read from here (local processed data)\n", | |
| " self.recv_sizes, # Read this amount\n", | |
| " self.bwd_remote_output_offsets, # Write to here (remote original sender)\n", | |
| " self.send_sizes, # Write this amount\n", | |
| " axis_name=self.axis_name,\n", | |
| " )\n", | |
| " # Results land at input_offsets, which were computed from the full (pre-drop) load,\n", | |
| " # so dropped tokens leave gaps inside each device group rather than only at the end.\n", | |
| " pos = jnp.arange(capacity)\n", | |
| " group = jnp.searchsorted(self.input_offsets, pos, side=\"right\") - 1\n", | |
| " is_valid = pos < self.input_offsets[group] + self.send_sizes[group]\n", | |
| " return returned_to_original_sender, is_valid" | |
| ] | |
| }, | |
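To see how the offsets exchanged in `from_target_device_ids` fit together, here is a pure-Python simulation in which devices are list entries. It is a sketch under simplifying assumptions (no mask, so every token is sent; each device's tokens already sorted by target device; hypothetical helper names): `transpose` models `jax.lax.all_to_all(..., split_axis=0, concat_axis=0, tiled=True)` on one length-n vector per device, and the two loops model the forward and return `ragged_all_to_all` calls in `dispatch_forward` and `dispatch_backward`.

```python
from itertools import accumulate


def exclusive_cumsum(sizes):
    # Same role as calculate_offsets: start index of each chunk
    return [0] + list(accumulate(sizes))[:-1]


def transpose(per_device):
    # all_to_all on one length-n vector per device: device j receives
    # element j from every device i and stores it at position i
    return [list(col) for col in zip(*per_device)]


def simulate_dispatch(tokens, targets, num_devices):
    send_sizes = [[t.count(d) for d in range(num_devices)] for t in targets]
    input_offsets = [exclusive_cumsum(s) for s in send_sizes]
    recv_sizes = transpose(send_sizes)
    recv_offsets = [exclusive_cumsum(r) for r in recv_sizes]
    fwd_remote_output_offsets = transpose(recv_offsets)
    bwd_remote_output_offsets = transpose(input_offsets)

    # Forward: chunk i -> j is read at input_offsets[i][j] on device i
    # and written at fwd_remote_output_offsets[i][j] on device j
    received = [[None] * sum(recv_sizes[j]) for j in range(num_devices)]
    for i in range(num_devices):
        for j in range(num_devices):
            n = send_sizes[i][j]
            src, dst = input_offsets[i][j], fwd_remote_output_offsets[i][j]
            received[j][dst:dst + n] = tokens[i][src:src + n]

    # Return trip: device j reads at recv_offsets[j][i] and writes at
    # bwd_remote_output_offsets[j][i] (the original input offset) on device i
    returned = [[None] * len(t) for t in tokens]
    for j in range(num_devices):
        for i in range(num_devices):
            n = recv_sizes[j][i]
            src, dst = recv_offsets[j][i], bwd_remote_output_offsets[j][i]
            returned[i][dst:dst + n] = received[j][src:src + n]
    return received, returned


tokens = [["a0", "a1", "a2"], ["b0", "b1"], ["c0", "c1", "c2"]]
targets = [[0, 0, 1], [1, 2], [0, 2, 2]]  # target device per token, sorted per device
received, returned = simulate_dispatch(tokens, targets, num_devices=3)
print(received)  # [['a0', 'a1', 'c0'], ['a2', 'b0'], ['b1', 'c1', 'c2']]
```

Tokens arrive packed by source device on each expert device and come back to their original slots. With a mask, `input_offsets` still come from the full load while `send_sizes` count only kept tokens, so returned results land in their original slots and the dropped slots keep the zeros of the pre-allocated output buffer.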
| { | |
| "cell_type": "code", | |
| "execution_count": 54, | |
| "id": "b2b86bd6-a7b6-4cc3-a196-444b88dcc5b1", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def moe_layer(\n", | |
| " sequences: Float[Array, \"batch seq hidden\"],\n", | |
| " router_weights: Float[Array, \"hidden experts_total\"],\n", | |
| " expert_weights: Float[Array, \"experts_total hidden hidden\"],\n", | |
| " *,\n", | |
| " mesh: Mesh,\n", | |
| " top_k: int,\n", | |
| " expert_axis_name: str | tuple[str, ...] = \"ep\",\n", | |
| " capacity_factor: float = 1.2,\n", | |
| ") -> Float[Array, \"batch seq hidden\"]:\n", | |
| " \"\"\"Run a shard-mapped Mixture-of-Experts (MoE) layer using ragged all-to-all communication.\n", | |
| "\n", | |
| " Parameters\n", | |
| " ----------\n", | |
| " sequences : Float[Array, \"batch seq hidden\"]\n", | |
| " Input token sequences.\n", | |
| " router_weights : Float[Array, \"hidden experts_total\"]\n", | |
| " Weights for the gating network.\n", | |
| " expert_weights : Float[Array, \"experts_total hidden hidden\"]\n", | |
| " Weights for the experts (sharded).\n", | |
| " mesh : Mesh\n", | |
| " The JAX device mesh.\n", | |
| " top_k : int\n", | |
| " Number of experts to select per token.\n", | |
| " expert_axis_name : str | tuple[str, ...], optional\n", | |
| " Name of the mesh axis for experts, by default \"ep\".\n", | |
| " capacity_factor : float, optional\n", | |
| " Multiplier for expert capacity, by default 1.2.\n", | |
| " expert_capacity = ceil(total_tokens * top_k / total_experts * capacity_factor),\n", | |
| " where total_tokens counts the tokens across the expert-parallel axis.\n", | |
| "\n", | |
| " Returns\n", | |
| " -------\n", | |
| " Float[Array, \"batch seq hidden\"]\n", | |
| " The processed sequences.\n", | |
| " \"\"\"\n", | |
| "\n", | |
| " @jax.jit\n", | |
| " @jax.shard_map(\n", | |
| " mesh=mesh,\n", | |
| " in_specs=(\n", | |
| " P(\"dp\", expert_axis_name, None), # sequences[batch@dp, seq@ep, hidden]\n", | |
| " P(None, None), # router_weights[hidden, experts_total]\n", | |
| " # expert_weights[experts_total@ep, hidden, hidden]\n", | |
| " P(expert_axis_name, None, None),\n", | |
| " ),\n", | |
| " out_specs=P(\"dp\", expert_axis_name, None),\n", | |
| " )\n", | |
| " def _sharded_moe_impl(\n", | |
| " x_shard: Float[Array, \"local_batch local_seq hidden\"],\n", | |
| " router_weights: Float[Array, \"hidden experts_total\"],\n", | |
| " local_expert_weights: Float[Array, \"local_experts hidden hidden\"],\n", | |
| " ) -> Float[Array, \"local_batch local_seq hidden\"]:\n", | |
| " # x_shard[local_batch{dp}, local_seq{ep}, hidden]\n", | |
| "\n", | |
| " # Setup and preparations\n", | |
| " num_devices_ep = jax.lax.axis_size(expert_axis_name)\n", | |
| " num_experts_per_device = local_expert_weights.shape[0]\n", | |
| " num_total_experts = router_weights.shape[-1]\n", | |
| " assert (\n", | |
| " num_total_experts == num_devices_ep * num_experts_per_device\n", | |
| " ), f\"{num_total_experts=} must equal {num_devices_ep=} * {num_experts_per_device=}\"\n", | |
| "\n", | |
| " # A. Calculating Token Routing\n", | |
| " # 1. Flatten Local Batch: (b, s, hidden) -> (b*s, hidden)\n", | |
| " tokens = x_shard.reshape(-1, x_shard.shape[-1])\n", | |
| "\n", | |
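| " # Expert capacity = (total tokens routed / total experts) * capacity_factor, rounded up.\n", | |
| " # Computed with plain Python numbers (not jnp) so it stays a static int: it sizes the\n", | |
| " # dispatch buffers, and array shapes must be concrete under jit.\n", | |
| " expert_capacity = int(\n", | |
| " -(\n", | |
| " -tokens.shape[0]\n", | |
| " * top_k\n", | |
| " * num_devices_ep\n", | |
| " * capacity_factor\n", | |
| " // num_total_experts\n", | |
| " )\n", | |
| " ) # ceiling division: -(-a // b) == ceil(a / b)\n", | |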
| " device_capacity = expert_capacity * num_experts_per_device\n", | |
| "\n", | |
| " # 2. Routing (Top-K): (b*s, hidden) @ (hidden, experts_total) -> (b*s, experts_total) -> topk -> (b*s, k)\n", | |
| " top_k_logits, top_k_expert_ids = jax.lax.top_k(tokens @ router_weights, k=top_k)\n", | |
| " top_k_expert_probs = jax.nn.softmax(top_k_logits, axis=-1)\n", | |
| "\n", | |
| " # 3. Expand: (b*s, hidden) -> (b*s*k, hidden)\n", | |
| " expert_ids_flat = top_k_expert_ids.ravel() # (b*s, k) -> (b*s*k,)\n", | |
| " expert_probs_flat = top_k_expert_probs.ravel() # (b*s, k) -> (b*s*k,)\n", | |
| " # Repeat interleave each token k times to create dispatch tokens\n", | |
| " tokens_flat = jnp.repeat(tokens, repeats=top_k, axis=0) # (b*s*k, hidden)\n", | |
| " # Track original token segment IDs to sum k results back later: (b*s*k,)\n", | |
| " token_segment_ids = jnp.repeat(\n", | |
| " jnp.arange(tokens.shape[0]), repeats=top_k, axis=0\n", | |
| " )\n", | |
| "\n", | |
| " # B. Token Dispatching with Token Dropping on senders' side\n", | |
| " is_token_dropped = calculate_drop_token_mask(\n", | |
| " token_expert_ids=expert_ids_flat,\n", | |
| " token_expert_probs=expert_probs_flat,\n", | |
| " num_experts=num_total_experts,\n", | |
| " expert_capacity=expert_capacity,\n", | |
| " )\n", | |
| " # Sort the data to dispatch (e.g. tokens_flat) so that items are grouped by:\n", | |
| " # 1. Target device (ascending)\n", | |
| " # 2. Dropped or not: dropped tokens go last within each device group, so they are never sent\n", | |
| " # 3. Local expert ID within the target device\n", | |
| " target_device_ids, target_local_expert_ids = jnp.divmod(\n", | |
| " expert_ids_flat, num_experts_per_device\n", | |
| " )\n", | |
| " index_to_sorted_for_dispatch = jnp.lexsort(\n", | |
| " (target_local_expert_ids, is_token_dropped, target_device_ids)\n", | |
| " )\n", | |
| "\n", | |
| " tokens_sorted_for_dispatch = tokens_flat[index_to_sorted_for_dispatch]\n", | |
| " expert_probs_sorted_for_dispatch = expert_probs_flat[\n", | |
| " index_to_sorted_for_dispatch\n", | |
| " ]\n", | |
| " target_local_expert_ids_sorted_for_dispatch = target_local_expert_ids[\n", | |
| " index_to_sorted_for_dispatch\n", | |
| " ]\n", | |
| " token_segment_ids_sorted_for_dispatch = token_segment_ids[\n", | |
| " index_to_sorted_for_dispatch\n", | |
| " ]\n", | |
| " # Initialize Device Dispatcher\n", | |
| " dispatcher = MOERaggedDispatcher.from_target_device_ids(\n", | |
| " target_device_ids, axis_name=expert_axis_name, mask=~is_token_dropped\n", | |
| " )\n", | |
| " # Dispatch forward: send tokens to the expert devices (expert capacity is already enforced on the senders' side).\n", | |
| " # In the worst case a receiver gets a full device_capacity from every sender, so its buffer is provisioned for that.\n", | |
| " # Note: receiver_capacity ~= (tokens across the ep axis) * top_k * capacity_factor, so for a fixed global token count it does not grow as devices are added.\n", | |
| " receiver_capacity = device_capacity * num_devices_ep\n", | |
| " recv_tokens, _ = dispatcher.dispatch_forward(\n", | |
| " tokens_sorted_for_dispatch, receiver_capacity\n", | |
| " )\n", | |
| " recv_probs, _ = dispatcher.dispatch_forward(\n", | |
| " expert_probs_sorted_for_dispatch, receiver_capacity\n", | |
| " )\n", | |
| " recv_local_expert_ids, is_valid = dispatcher.dispatch_forward(\n", | |
| " target_local_expert_ids_sorted_for_dispatch, receiver_capacity\n", | |
| " )\n", | |
| "\n", | |
| " # C. Expert Computation\n", | |
| " # Group tokens by local expert ids (0 to experts_per_device-1) for ragged dot\n", | |
| " # Note: we move invalid tokens (0s from output buffer) to the end so they are ignored in computation\n", | |
| " safe_local_expert_ids = jnp.where(\n", | |
| " is_valid, recv_local_expert_ids, num_experts_per_device\n", | |
| " )\n", | |
| " indices_to_sort_by_local_expert = jnp.argsort(safe_local_expert_ids)\n", | |
| " recv_tokens_by_local_expert = recv_tokens[indices_to_sort_by_local_expert]\n", | |
| " recv_probs_by_local_expert = recv_probs[indices_to_sort_by_local_expert]\n", | |
| " # Calculate batch size per local expert for ragged_dot\n", | |
| " local_expert_token_count = jnp.bincount(\n", | |
| " safe_local_expert_ids,\n", | |
| " # Effectively mask out invalid tokens\n", | |
| " weights=is_valid.astype(safe_local_expert_ids.dtype),\n", | |
| " length=num_experts_per_device,\n", | |
| " )\n", | |
| " # Perform MOE computation: (Batch, Hidden) @ (Hidden, Hidden) for each expert group\n", | |
| " # Note: expert_outputs.shape[0] = recv_tokens_by_local_expert.shape[0] with\n", | |
| " # expert_outputs[sum(local_expert_token_count):] = 0\n", | |
| " expert_outputs = jax.lax.ragged_dot(\n", | |
| " recv_tokens_by_local_expert,\n", | |
| " local_expert_weights,\n", | |
| " group_sizes=local_expert_token_count,\n", | |
| " )\n", | |
| " # Apply gating weights\n", | |
| " expert_outputs = expert_outputs * recv_probs_by_local_expert[:, None]\n", | |
| "\n", | |
| " # D. Dispatch Back (Return Trip)\n", | |
| " # Restore Order (Undo local expert sort)\n", | |
| " indices_invert_sort_by_local_expert = jnp.argsort(\n", | |
| " indices_to_sort_by_local_expert\n", | |
| " )\n", | |
| " packed_outputs = expert_outputs[indices_invert_sort_by_local_expert]\n", | |
| " tokens_moe_back, _ = dispatcher.dispatch_backward(\n", | |
| " packed_outputs, capacity=tokens_sorted_for_dispatch.shape[0]\n", | |
| " )\n", | |
| "\n", | |
| " # E. Combine: sum the top_k expert outputs of each token (segment sum).\n", | |
| " # Dropped (token, expert) pairs are never sent, so their slots come back as zeros and add nothing.\n", | |
| " # (b*s*k, hidden) -> (b*s, hidden), indexed by original token position via segment_ids\n", | |
| " combined_output = jax.ops.segment_sum(\n", | |
| " tokens_moe_back,\n", | |
| " segment_ids=token_segment_ids_sorted_for_dispatch,\n", | |
| " num_segments=tokens.shape[0],\n", | |
| " )\n", | |
| "\n", | |
| " return combined_output.reshape(x_shard.shape)\n", | |
| "\n", | |
| " return _sharded_moe_impl(sequences, router_weights, expert_weights)" | |
| ] | |
| }, | |
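The sender-side bookkeeping in `moe_layer` (capacity sizing, the three-key `lexsort`, and the final segment sum) can be checked on one host with NumPy. The sketch below is illustrative only: `receive_buffer_size` is a hypothetical helper that mirrors the capacity formulas, the routing and drop mask are made up, `np.lexsort` and `np.add.at` stand in for `jnp.lexsort` and `jax.ops.segment_sum`, and dropped pairs are assumed to come back as zeros because they are never sent.

```python
import math

import numpy as np


def receive_buffer_size(global_tokens: int, top_k: int, num_experts: int,
                        num_devices: int, capacity_factor: float) -> int:
    """Receiver capacity as sized in moe_layer, for a given expert-parallel device count."""
    local_tokens = global_tokens // num_devices
    expert_capacity = math.ceil(local_tokens * top_k * num_devices / num_experts * capacity_factor)
    device_capacity = expert_capacity * (num_experts // num_devices)
    return device_capacity * num_devices


# For a fixed global token count the buffer stays ~ global_tokens * top_k * capacity_factor.
sizes = [receive_buffer_size(4096, 2, 64, d, 1.25) for d in (2, 4, 8, 16)]
print(sizes)  # [10240, 10240, 10240, 10240]

# Sender-side ordering for 3 tokens, top_k=2, 4 experts over 2 devices (2 experts per device).
top_k, experts_per_device = 2, 2
expert_ids_flat = np.array([3, 0, 1, 2, 0, 3])  # (tokens * top_k,), hypothetical routing
is_dropped = np.array([False, False, False, False, True, False])  # hypothetical drop mask
token_segment_ids = np.repeat(np.arange(3), top_k)  # [0, 0, 1, 1, 2, 2]

target_device, local_expert = np.divmod(expert_ids_flat, experts_per_device)
# np.lexsort treats the LAST key as primary: device, then kept-before-dropped, then local expert.
order = np.lexsort((local_expert, is_dropped, target_device))
print(order.tolist())  # [1, 2, 4, 3, 0, 5]

# Pretend each kept (token, expert) pair comes back as 10 * (expert_id + 1); dropped pairs stay 0.
returned = np.where(is_dropped[order], 0, 10 * (expert_ids_flat[order] + 1))
# Segment sum back to one row per token, using the segment ids in the same sorted order.
combined = np.zeros(3, dtype=returned.dtype)
np.add.at(combined, token_segment_ids[order], returned)
print(combined.tolist())  # [50, 50, 40]
```

Note that the dropped pair (index 4) sits at the end of device 0's group rather than at the end of the whole buffer, and that token 2 only receives the contribution of its surviving expert.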
| { | |
| "cell_type": "markdown", | |
| "id": "d9acfd28-6e97-4391-b052-c946d7afb86e", | |
| "metadata": {}, | |
| "source": [ | |
| "### Knowledge Distillation\n", | |
| "We split the available devices into two disjoint meshes:\n", | |
| "* `mesh_inference`: Runs Teacher Model Forward Pass only\n", | |
| "* `mesh_train`: Runs Student Model Training (Forward + Backward + Update)\n", | |
| " \n", | |
| "Because JAX uses **asynchronous dispatch**, the body of the Python `for` loop below does not wait for device computation to finish. \n", | |
| "JAX can enqueue `teacher_forward` for step *i + 1* on `mesh_inference` while `mesh_train` is still computing gradients in `train_student_step` for step *i*, **so both meshes work simultaneously as a simple pipeline**. This only holds as long as the host does not wait on a device result inside the loop; printing a JAX array, for example, forces a synchronization." | |
| ] | |
| }, | |
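One way to reason about this overlap is a toy model in which each mesh is an in-order work queue and every dispatch returns a future immediately. The sketch below uses `concurrent.futures` with one single-worker executor per "mesh"; it is an analogy, not JAX's runtime, and the helper `run` and the sleep durations are arbitrary assumptions. It shows the teacher work for step 1 starting while the student work for step 0 is still running, because the host loop never waits on a result.

```python
import time
from concurrent.futures import ThreadPoolExecutor

# One in-order queue per "mesh" (toy stand-ins for per-device execution streams).
teacher_queue = ThreadPoolExecutor(max_workers=1)
student_queue = ThreadPoolExecutor(max_workers=1)
t0 = time.perf_counter()
events = []  # (label, start, end) in seconds since t0


def run(label, seconds, dependency=None):
    if dependency is not None:
        dependency.result()  # wait for the input produced on the other "mesh"
    start = time.perf_counter() - t0
    time.sleep(seconds)  # pretend to compute
    events.append((label, start, time.perf_counter() - t0))
    return label


student_futures = []
for step in range(3):
    # Both submits return immediately, like JAX's asynchronous dispatch.
    teacher_out = teacher_queue.submit(run, f"teacher[{step}]", 0.02)
    student_futures.append(student_queue.submit(run, f"student[{step}]", 0.2, teacher_out))
    # Calling student_futures[-1].result() here (the analogue of printing the loss)
    # would block the host loop and serialize the two queues.

for future in student_futures:
    future.result()
teacher_queue.shutdown()
student_queue.shutdown()

spans = {label: (start, end) for label, start, end in events}
# teacher[1] starts before student[0] ends, so the two "meshes" overlap.
print(spans["teacher[1]"][0] < spans["student[0]"][1])  # True
```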
| { | |
| "cell_type": "code", | |
| "execution_count": 55, | |
| "id": "d9abd854-27da-42f9-841a-89a1d16389d7", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "devices = jax.devices()\n", | |
| "# Split devices: first half for the teacher, second half for the student (e.g. 4 + 4 on 8 devices; adjust to your hardware and each model's memory footprint)\n", | |
| "split_idx = len(devices) // 2\n", | |
| "\n", | |
| "# Tensor Parallel for big model inference\n", | |
| "mesh_inference = Mesh(devices[:split_idx], axis_names=(\"tp\",))\n", | |
| "\n", | |
| "# Data Parallel for small model training\n", | |
| "mesh_train = Mesh(devices[split_idx:], axis_names=(\"dp\",))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 56, | |
| "id": "94acbd52-9613-4d2b-8a71-039a01f8d2e3", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Initialize models on separate device meshes\n", | |
| "with jax.set_mesh(mesh_inference):\n", | |
| " key_model_teacher, key_batch = jax.random.split(jax.random.key(0), num=2)\n", | |
| " model_teacher = Model(layer_sizes=(32, 128, 256, 128, 4), key=key_model_teacher)\n", | |
| " # init some fake data\n", | |
| " batch = model_teacher.init_random_batch(key=key_batch, batch_size=16)\n", | |
| "with jax.set_mesh(mesh_train):\n", | |
| " model_student = Model(layer_sizes=(32, 4), key=jax.random.key(1))" | |
| ] | |
| }, | |
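The `shard_map` specs used below impose divisibility constraints: with `P(None, "tp")` activations, `P("tp", None)` weights and `P("tp")` biases, every teacher width (including the input width) must be divisible by the `tp` size, and the batch must be divisible by the `dp` size. A small pure-Python check makes the limits explicit for the sizes used here; `check_split` is a hypothetical helper and the device counts are assumptions.

```python
from functools import reduce
from math import gcd


def check_split(num_devices: int, split_idx: int, teacher_layer_sizes, batch_size: int):
    """Return (tp_size, dp_size) or raise if the split is incompatible with the shard_map specs."""
    tp_size, dp_size = split_idx, num_devices - split_idx
    # tp_size divides every width exactly when it divides their greatest common divisor.
    max_tp = reduce(gcd, teacher_layer_sizes)
    if max_tp % tp_size != 0:
        raise ValueError(f"tp={tp_size} must divide every width in {teacher_layer_sizes}")
    # The student's batch dimension is split along 'dp'.
    if batch_size % dp_size != 0:
        raise ValueError(f"dp={dp_size} must divide batch_size={batch_size}")
    return tp_size, dp_size


teacher_sizes = (32, 128, 256, 128, 4)
print(check_split(8, 8 // 2, teacher_sizes, batch_size=16))  # (4, 4)
try:
    check_split(16, 8, teacher_sizes, batch_size=16)
except ValueError as err:
    print(err)  # tp=8 must divide every width in (32, 128, 256, 128, 4)
```

With these layer sizes the 4-wide output layer caps the teacher at 4-way tensor parallelism, so on larger machines the extra devices are better given to the student's `dp` axis.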
| { | |
| "cell_type": "code", | |
| "execution_count": 57, | |
| "id": "1d02b631-7b80-4773-8d60-3c659203dcc6", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Shard weights and input data for tensor-parallel (TP) inference\n", | |
| "linear_pspec: LinearPSpec = Linear(w=P(\"tp\", None), b=P(\"tp\"))\n", | |
| "model_teacher_pspec = Model(\n", | |
| " layer_sizes=model_teacher.layer_sizes,\n", | |
| " layers=[linear_pspec] * model_teacher.num_layers,\n", | |
| ")\n", | |
| "\n", | |
| "\n", | |
| "@jax.jit\n", | |
| "@jax.shard_map(\n", | |
| " mesh=mesh_inference,\n", | |
| " in_specs=(model_teacher_pspec, P(None, \"tp\")),\n", | |
| " out_specs=P(None, \"tp\"), # Output is sharded on 'tp' axis\n", | |
| ")\n", | |
| "def teacher_forward(\n", | |
| " model_sharded: Model, x: Float[Array, \"batch h\"]\n", | |
| ") -> Float[Array, \"batch C\"]:\n", | |
| " \"\"\"Tensor-parallel forward pass of the (mock) teacher. Returns logits sharded over 'tp'.\"\"\"\n", | |
| "\n", | |
| " def linear_tp(x: jax.Array, layer: LinearParams) -> jax.Array:\n", | |
| " \"\"\"Column-sharded x @ row-sharded w, then reduce-scatter the partial sums over 'tp'.\"\"\"\n", | |
| " matmul_local = x @ layer.w\n", | |
| " matmul_reduce_scattered = jax.lax.psum_scatter(\n", | |
| " matmul_local, \"tp\", scatter_dimension=1, tiled=True\n", | |
| " )\n", | |
| " return matmul_reduce_scattered + layer.b\n", | |
| "\n", | |
| " return model_sharded.predict(x, linear_fwd_fn=linear_tp)" | |
| ] | |
| }, | |
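To see why `linear_tp` matches an unsharded layer, here is a NumPy simulation on one host: `x` is split by columns and `w` by rows across `tp` shards, each shard computes a full-size partial product, and a tiled reduce-scatter (sum over shards, then each shard keeps its own column block) yields a column-sharded output, which is already the input layout the next layer expects. The shapes and the 4-way split are arbitrary; `np.split` plays the role of the `P(None, "tp")` / `P("tp", None)` / `P("tp")` specs and the explicit sum plays the role of `jax.lax.psum_scatter(..., tiled=True)`.

```python
import numpy as np

rng = np.random.default_rng(0)
tp, batch, d_in, d_out = 4, 3, 8, 12
x = rng.normal(size=(batch, d_in))
w = rng.normal(size=(d_in, d_out))
b = rng.normal(size=(d_out,))

x_shards = np.split(x, tp, axis=1)  # P(None, "tp"): each shard holds d_in / tp columns
w_shards = np.split(w, tp, axis=0)  # P("tp", None): each shard holds d_in / tp rows
b_shards = np.split(b, tp)          # P("tp"): each shard holds d_out / tp biases

# Each shard's local matmul is a full-size partial sum of shape (batch, d_out).
partials = [xs @ ws for xs, ws in zip(x_shards, w_shards)]

# Tiled reduce-scatter: sum the partials, then shard i keeps column block i.
summed = sum(partials)
out_shards = [blk + bs for blk, bs in zip(np.split(summed, tp, axis=1), b_shards)]

# Concatenating the column blocks reproduces the unsharded layer.
print(np.allclose(np.concatenate(out_shards, axis=1), x @ w + b))  # True
```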
| { | |
| "cell_type": "code", | |
| "execution_count": 58, | |
| "id": "b1bb9342-8a10-4209-83af-5c278a4c69eb", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Data Parallel for training small model: Only shard data\n", | |
| "linear_pspec: LinearPSpec = Linear(w=P(), b=P())\n", | |
| "model_student_pspec = Model(\n", | |
| " layer_sizes=model_student.layer_sizes,\n", | |
| " layers=[linear_pspec] * model_student.num_layers,\n", | |
| ")\n", | |
| "\n", | |
| "\n", | |
| "@jax.value_and_grad\n", | |
| "def kd_loss_dp(\n", | |
| " model_student: Model,\n", | |
| " x: Float[Array, \"batch_local hidden\"],\n", | |
| " y_pred_teacher: Float[Array, \"...\"],\n", | |
| ") -> Float[Array, \"\"]:\n", | |
| " \"\"\"Knowledge Distillation Loss\"\"\"\n", | |
| " y_pred_student = model_student.predict(x, linear_fwd_fn=linear)\n", | |
| " loss_dp_local = jnp.sum((y_pred_student - y_pred_teacher) ** 2, axis=-1).mean()\n", | |
| " return jax.lax.pmean(loss_dp_local, axis_name=\"dp\")\n", | |
| "\n", | |
| "\n", | |
| "@partial(jax.jit, donate_argnames=[\"model_student\"]) # buffer donation\n", | |
| "@jax.shard_map(\n", | |
| " mesh=mesh_train,\n", | |
| " # Only shard data (X,y) for DP\n", | |
| " in_specs=(model_student_pspec, P(\"dp\", None), P(\"dp\", None)),\n", | |
| " out_specs=(model_student_pspec, P()),\n", | |
| ")\n", | |
| "def train_student_step(\n", | |
| " model_student: Model,\n", | |
| " x: Float[Array, \"batch_local hidden\"],\n", | |
| " target_logits: Float[Array, \"batch_local C\"],\n", | |
| ") -> tuple[Model, Float[Array, \"\"]]:\n", | |
| " loss, grads = kd_loss_dp(model_student, x, target_logits)\n", | |
| " # Update weights\n", | |
| " model_student_updated = jax.tree.map(\n", | |
| " lambda w, w_grad: w - 0.01 * w_grad, model_student, grads\n", | |
| " )\n", | |
| " return model_student_updated, loss" | |
| ] | |
| }, | |
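Why is a `pmean` of the local losses enough for data parallelism? With equally sized shards, the global loss is the mean of the local losses, so its gradient is the mean of the local gradients; we take JAX's `shard_map` handling of gradients for the replicated student parameters as given rather than show it here. The NumPy check below only verifies the arithmetic, for a linear student with the same squared-error loss and a hand-derived gradient; `loss_and_grads`, the shapes and the 4-way split are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
num_shards, batch, d_in, n_cls = 4, 16, 32, 4
x = rng.normal(size=(batch, d_in))
target = rng.normal(size=(batch, n_cls))
w = rng.normal(size=(d_in, n_cls)) * 0.1
b = np.zeros(n_cls)


def loss_and_grads(x, target, w, b):
    """Squared-error KD loss (sum over classes, mean over batch) and its analytic gradients."""
    residual = x @ w + b - target
    loss = np.sum(residual**2, axis=-1).mean()
    grad_w = 2.0 * x.T @ residual / x.shape[0]
    grad_b = 2.0 * residual.sum(axis=0) / x.shape[0]
    return loss, grad_w, grad_b


# Per-shard loss and gradients, then an average over shards (what pmean amounts to for equal shards).
local = [loss_and_grads(xs, ts, w, b)
         for xs, ts in zip(np.split(x, num_shards), np.split(target, num_shards))]
mean_loss, mean_gw, mean_gb = (np.mean(np.stack(parts), axis=0) for parts in zip(*local))

full_loss, full_gw, full_gb = loss_and_grads(x, target, w, b)
print(np.isclose(mean_loss, full_loss), np.allclose(mean_gw, full_gw), np.allclose(mean_gb, full_gb))
# True True True
```

With unequal shard sizes a plain mean would over-weight the smaller shards, which is one reason the batch is kept divisible by the `dp` size.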
| { | |
| "cell_type": "code", | |
| "execution_count": 59, | |
| "id": "f42998d0-b788-47c3-b085-296bfb175b03", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Loss: 1.09\n", | |
| "Loss: 1.01\n", | |
| "Loss: 0.94\n", | |
| "Loss: 0.88\n", | |
| "Loss: 0.83\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Train loop\n", | |
| "losses = []\n", | |
| "for i in range(5):\n", | |
| " # Teacher inference on mesh_inference\n", | |
| " with jax.set_mesh(mesh_inference):\n", | |
| " y_pred_teacher = teacher_forward(model_teacher, batch.inputs)\n", | |
| "\n", | |
| " # Student training on mesh_train\n", | |
| " with jax.set_mesh(mesh_train):\n", | |
| " # Move teacher logits and inputs across meshes, sharded over 'dp'\n", | |
| " y_pred_teacher = jax.device_put(\n", | |
| " y_pred_teacher, NamedSharding(mesh_train, P(\"dp\", None))\n", | |
| " )\n", | |
| " batch_inputs = jax.device_put(\n", | |
| " batch.inputs, NamedSharding(mesh_train, P(\"dp\", None))\n", | |
| " )\n", | |
| " model_student, loss = train_student_step(\n", | |
| " model_student, batch_inputs, y_pred_teacher\n", | |
| " )\n", | |
| " # Keep the device array; printing it here would block the host every step\n", | |
| " losses.append(loss)\n", | |
| "\n", | |
| "for loss in losses:\n", | |
| " print(f\"Loss: {loss:.2f}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "77fc6ef2-a584-4661-8c20-37a417bac9c0", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3 (ipykernel)", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.13.11" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 5 | |
| } |