Last active
January 17, 2026 19:12
-
-
Save mdmitry1/21a76345acf2ea8b823e34e2a9f684eb to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/python3.14 | |
| # -*- coding: utf-8 -*- | |
| import torch | |
| import math | |
| from hashlib import sha256 | |
| from matplotlib import pyplot as plt | |
def _fit_cubic(x: torch.Tensor, y: torch.Tensor, iterations: int,
               learning_rate: float) -> tuple:
    """Fit ``a + b*x + c*x^2 + d*x^3`` to *y* with hand-written gradient descent.

    The four coefficients are drawn from ``torch.randn`` after seeding the
    generator with 42, so repeated runs start from the same weights.  The
    RMSE is printed every 1000th iteration.

    Args:
        x: Sample points.
        y: Target values at the sample points.
        iterations: Number of gradient-descent steps.
        learning_rate: Step size applied to the summed gradients.

    Returns:
        The tuple ``(a, b, c, d)`` of zero-dimensional coefficient tensors.
    """
    # Fixed seed keeps the initial weights identical between runs.
    torch.manual_seed(42)
    a = torch.randn((), device=x.device, dtype=x.dtype)
    b = torch.randn((), device=x.device, dtype=x.dtype)
    c = torch.randn((), device=x.device, dtype=x.dtype)
    d = torch.randn((), device=x.device, dtype=x.dtype)

    # The powers of x never change, so compute them once instead of on
    # every iteration.
    x2 = x ** 2
    x3 = x ** 3
    size = x.numel()

    print("Iteration RMSE")
    for t in range(iterations):
        # Forward pass: compute predicted y
        y_pred = a + b * x + c * x2 + d * x3
        residual = y_pred - y

        # The loss is only needed for reporting, so the reduction and the
        # tensor-to-float conversion are skipped on all other iterations.
        if t % 1000 == 999:
            loss = residual.pow(2).sum().item()
            print(t + 1, math.sqrt(loss / size))

        # Backprop to compute gradients of a, b, c, d with respect to loss
        grad_y_pred = 2.0 * residual
        grad_a = grad_y_pred.sum()
        grad_b = (grad_y_pred * x).sum()
        grad_c = (grad_y_pred * x2).sum()
        grad_d = (grad_y_pred * x3).sum()

        # Update weights using gradient descent
        a -= learning_rate * grad_a
        b -= learning_rate * grad_b
        c -= learning_rate * grad_c
        d -= learning_rate * grad_d
    return a, b, c, d


def _show_fit(x: torch.Tensor, y: torch.Tensor, fitted: torch.Tensor,
              timeout: float) -> None:
    """Plot the target and the fitted curve; auto-close after *timeout* ms if finite."""
    plt.plot(x, y, x, fitted)
    plt.grid()
    plt.gcf().canvas.manager.set_window_title('SIN(X) PyTorch approximation')
    plt.title("Plot Example")
    plt.legend(['sin', 'sin predicted'])
    plt.xlabel('x')
    plt.ylabel('sin(x) and pytorch prediction')
    # Only a finite timeout arms the auto-close timer; infinity (and any
    # other non-finite value) leaves the window open until it is closed.
    if math.isfinite(float(timeout)):
        timer = plt.gcf().canvas.new_timer(
            interval=timeout, callbacks=[(plt.close, [], {})])
        timer.start()
    plt.show()


def main(timeout: float = 5000, size: int = 100000, *,
         iterations: int = 60000, learning_rate: float = 1e-6) -> str:
    """Approximate sin(x) on [0, pi/2] with a cubic polynomial and plot the fit.

    The polynomial coefficients are trained with manually derived gradients
    (no autograd), the resulting formula is printed, the target and fitted
    curves are plotted, and the maximum absolute error is printed once the
    plot call returns.

    Args:
        timeout: Milliseconds after which the plot window is closed by a
            timer.  Pass ``math.inf`` to leave the window open.
        size: Number of evenly spaced sample points on [0, pi/2].
        iterations: Number of gradient-descent steps.
        learning_rate: Gradient-descent step size.

    Returns:
        The SHA-256 hex digest of the printed ``Result: ...`` line.
    """
    dtype = torch.float
    device = torch.device("cpu")
    # device = torch.device("cuda:0")  # Uncomment this to run on GPU

    # Deterministic training data: evenly spaced x and the exact sine.
    x = torch.linspace(0, math.pi / 2, size, device=device, dtype=dtype)
    y = torch.sin(x)

    a, b, c, d = _fit_cubic(x, y, iterations, learning_rate)

    result = f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3'
    print(result)

    # Evaluate the fitted polynomial once and reuse it for the error and the plot.
    fitted = a + b * x + c * x ** 2 + d * x ** 3
    err = fitted - y

    _show_fit(x, y, fitted, timeout)

    print(f"Maximum error = {torch.max(torch.abs((err))).item():.6f}")
    return sha256(result.encode()).hexdigest()
if __name__ == "__main__":
    # An infinite timeout means main() starts no auto-close timer, so the
    # plot stays up until the user closes it.
    digest = main(math.inf)
    print(digest)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import sys | |
| from pytorch_ex import main | |
def test_pytorch_ex(monkeypatch):
    """Run main() with default arguments and compare its digest to the recorded one."""
    expected_digest = '92138878ecd663f9ff1e8313be1ffd1fb31d87e57ec64a0883672fe30fa5e2e3'
    with monkeypatch.context() as patcher:
        # Replace the command line with a bare program name while main() runs.
        patcher.setattr(sys, 'argv', ['pytorch_ex'])
        print("")
        digest = main()
        assert digest == expected_digest
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment