Last active
January 8, 2021 14:43
-
-
Save tcwalther/f1f2a31a2f2fba3e8f2fa3ea99164002 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import tensorflow as tf\n", | |
| "import numpy as np\n", | |
| "import scipy" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# RFFT2D tests: `tf.signal.rfft2d` vs. `np.fft.rfft2`" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def test_fft_length_matches_input_size():\n", | |
| " input = np.array([1, 2, 3, 4, 3, 8, 6, 3, 5, 2, 7, 6, 9, 5, 8, 3]).reshape((4, 4))\n", | |
| " \n", | |
| " result = tf.signal.rfft2d(input, (4,4)).numpy()\n", | |
| " \n", | |
| " expected_result = np.array([75, -6-1j, 9, -10+5j, -3+2j, -6+11j,\n", | |
| " -15, -2+13j, -5, -10-5j, 3-6j, -6-11j])\n", | |
| " np.testing.assert_array_equal(result.reshape(-1), expected_result)\n", | |
| " \n", | |
| "test_fft_length_matches_input_size()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def test_fft_length_smaller_than_input_size():\n", | |
| " input = np.array([1, 2, 3, 4, 0, 3, 8, 6, 3, 0, 5, 2, 7, 6, 0, 9, 5, 8, 3, 0]).reshape((4, 5))\n", | |
| " \n", | |
| " np_result = np.fft.rfft2(input, (4, 4))\n", | |
| " tf_result = tf.signal.rfft2d(input, fft_length=(4, 4)).numpy()\n", | |
| " \n", | |
| " expected_result = np.array([75, -6-1j, 9, -10+5j, -3+2j, -6+11j,\n", | |
| " -15, -2+13j, -5, -10-5j, 3-6j, -6-11j])\n", | |
| " \n", | |
| " np.testing.assert_array_almost_equal(np_result.reshape(-1), expected_result)\n", | |
| " np.testing.assert_array_almost_equal(tf_result.reshape(-1), expected_result)\n", | |
| "\n", | |
| "test_fft_length_smaller_than_input_size()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def test_fft_length_greater_than_input_size():\n", | |
| " input = np.array([[1, 2, 3, 4],\n", | |
| " [3, 8, 6, 3],\n", | |
| " [5, 2, 7, 6]])\n", | |
| " \n", | |
| " np_result = np.fft.rfft2(input, (4, 8)).reshape(-1)\n", | |
| " tf_result = tf.signal.rfft2d(input, fft_length=(4, 8)).numpy().reshape(-1)\n", | |
| " \n", | |
| " expected_result = np.array([\n", | |
| " 50, 8.29289341-33.6776695j, -7+1j, 9.70710659-1.67766953j, 0,\n", | |
| " -10-20j, -16.3639603-1.12132037j, -5+1j, -7.19238806-2.05025244j, -6+2j,\n", | |
| " 10, -4.7781744-6.12132025j, -1+11j, 10.7781744+1.87867963j, 4,\n", | |
| " -10+20j, 11.1923885+11.9497471j, 5-5j, -3.63603902-3.12132025j, -6-2j])\n", | |
| "\n", | |
| " np.testing.assert_array_almost_equal(np_result, expected_result)\n", | |
| " np.testing.assert_array_almost_equal(tf_result, expected_result)\n", | |
| " \n", | |
| "test_fft_length_greater_than_input_size()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def test_input_dims_greater_than_2():\n", | |
| " input = np.array([1, 2, 3, 4, 3, 8, 6, 3, 5, 2, 7, 6, 7, 3, 23, 5]).reshape((2, 2, 4))\n", | |
| " \n", | |
| " np_result = np.fft.rfft2(input, (2, 4)).reshape(-1)\n", | |
| " tf_result = tf.signal.rfft2d(input, fft_length=(2, 4)).numpy().reshape(-1)\n", | |
| "\n", | |
| " expected_result = np.array([\n", | |
| " 30, -5-3j, -4,\n", | |
| " -10, 1+7j, 0,\n", | |
| " 58, -18+6j, 26,\n", | |
| " -18, 14+2j, -18\n", | |
| " ])\n", | |
| "\n", | |
| " np.testing.assert_array_almost_equal(np_result, expected_result)\n", | |
| " np.testing.assert_array_almost_equal(tf_result, expected_result)\n", | |
| "\n", | |
| "test_input_dims_greater_than_2()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# IRFFT2D tests: `tf.signal.irfft2d` vs. `np.fft.irfft2`" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def test_fft_length_matches_input_size():\n", | |
| " input = np.array([75, -6-1j, 9, -10+5j, -3+2j, -6+11j,\n", | |
| " -15, -2+13j, -5, -10-5j, 3-6j, -6-11j]).reshape(4, 3)\n", | |
| " \n", | |
| " np_result = np.fft.irfft2(input, (4, 4)).reshape(-1)\n", | |
| " tf_result = tf.signal.irfft2d(input, fft_length=(4, 4)).numpy().reshape(-1)\n", | |
| " \n", | |
| " expected_result = np.array([1, 2, 3, 4, 3, 8, 6, 3, 5, 2, 7, 6, 9, 5, 8, 3])\n", | |
| " \n", | |
| " np.testing.assert_array_almost_equal(np_result, expected_result)\n", | |
| " np.testing.assert_array_almost_equal(tf_result, expected_result)\n", | |
| " \n", | |
| "test_fft_length_matches_input_size()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "(4, 3)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "def test_fft_length_smaller_than_input_size():\n", | |
| " input = np.array([1, 2, 3, 4, 0, 3, 8, 6, 3, 0, 5, 2, 7, 6, 0, 9, 5, 8, 3, 0]).reshape((4, 5))\n", | |
| " \n", | |
| " np_result = np.fft.rfft2(input, (4, 4))\n", | |
| " tf_result = tf.signal.rfft2d(input, fft_length=(4, 4)).numpy()\n", | |
| " print(np_result.shape)\n", | |
| " \n", | |
| " expected_result = np.array([75, -6-1j, 9, -10+5j, -3+2j, -6+11j,\n", | |
| " -15, -2+13j, -5, -10-5j, 3-6j, -6-11j])\n", | |
| " \n", | |
| " np.testing.assert_array_almost_equal(np_result.reshape(-1), expected_result)\n", | |
| " np.testing.assert_array_almost_equal(tf_result.reshape(-1), expected_result)\n", | |
| "\n", | |
| "test_fft_length_smaller_than_input_size()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def test_fft_length_smaller_than_input_size():\n", | |
| " input = np.array([75, -6-1j, 9, -10+5j, -3+2j, -6+11j,\n", | |
| " -15, -2+13j, -5, -10-5j, 3-6j, -6-11j]).reshape(4, 3)\n", | |
| " \n", | |
| " np_result = np.fft.irfft2(input, (2, 2))\n", | |
| " tf_result = tf.signal.irfft2d(input, fft_length=(2, 2)).numpy()\n", | |
| " \n", | |
| " expected_result = np.array([14, 18.5,\n", | |
| " 20.5, 22]).reshape(2, 2)\n", | |
| " \n", | |
| " np.testing.assert_array_almost_equal(np_result, expected_result)\n", | |
| " np.testing.assert_array_almost_equal(tf_result, expected_result)\n", | |
| " \n", | |
| "test_fft_length_smaller_than_input_size()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def test_fft_length_greater_than_input_size():\n", | |
| " input = np.array([75, -6-1j, 9, -10+5j, -3+2j, -6+11j, -15, -2+13j, -5, -10-5j, 3-6j, -6-11j]).reshape((4, 3))\n", | |
| " \n", | |
| " np_result = np.fft.irfft2(input, (4, 8))\n", | |
| " tf_result = tf.signal.irfft2d(input, fft_length=(4, 8)).numpy()\n", | |
| " \n", | |
| " expected_result = np.array([[0.25, 0.54289322, 1.25, 1.25, 1.25, 1.95710678, 2.25, 1.25],\n", | |
| " [1.25, 2.85355339, 4.25, 3.91421356, 2.75, 2.14644661, 1.75, 1.08578644],\n", | |
| " [3., 1.43933983, 0.5, 2.14644661, 4., 3.56066017, 2.5, 2.85355339],\n", | |
| " [5.625, 3.65533009, 1.375, 3.3017767, 5.125, 2.59466991, 0.375, 2.9482233]])\n", | |
| " \n", | |
| " np.testing.assert_array_almost_equal(np_result, expected_result)\n", | |
| " np.testing.assert_array_almost_equal(tf_result, expected_result)\n", | |
| " \n", | |
| "test_fft_length_greater_than_input_size()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def test_input_dims_greater_than_2():\n", | |
| " input = np.array([30, -5-3j, -4,\n", | |
| " -10, 1+7j, 0,\n", | |
| " 58, -18+6j, 26,\n", | |
| " -18, 14+2j, -18]).reshape(2, 2, 3)\n", | |
| " \n", | |
| " np_result = np.fft.irfft2(input, (2, 4))\n", | |
| " tf_result = tf.signal.irfft2d(input, fft_length=(2, 4)).numpy()\n", | |
| " \n", | |
| " expected_result = np.array([1., 2., 3., 4., 3., 8., 6., 3.,\n", | |
| " 5., 2., 7., 6., 7., 3., 23., 5.]).reshape(2, 2, 4)\n", | |
| "\n", | |
| " np.testing.assert_array_almost_equal(np_result, expected_result)\n", | |
| " np.testing.assert_array_almost_equal(tf_result, expected_result)\n", | |
| " \n", | |
| "test_input_dims_greater_than_2()" | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python [conda env:tflite]", | |
| "language": "python", | |
| "name": "conda-env-tflite-py" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.6.12" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 4 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment