{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {"id": "-pU4c4pTzmnP"},
   "outputs": [],
   "source": [
    "from google.colab import drive\n",
    "drive.mount('/content/drive')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {"id": "R_filL-5F8HR"},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "# Folder on the mounted Google Drive that holds the four audio files.\n",
    "# The settings cell below builds its x_path / y_path entries from this same\n",
    "# folder, so the files checked here are exactly the files used for training.\n",
    "DATA_DIR = Path(\"/content/drive/MyDrive\")\n",
    "\n",
    "# I'm just gonna check that you were paying attention ;)\n",
    "#\n",
    "# CHECK THE PATHS!\n",
    "#\n",
    "for name in (\"x_train.wav\", \"y_train.wav\", \"x_test.wav\", \"y_test.wav\"):\n",
    "    path = DATA_DIR / name\n",
    "    if not path.exists():\n",
    "        raise RuntimeError(f\"I didn't find all of your data files. Where is {path}?\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {"id": "2g_4GtFuGlO8"},
   "source": [
    "## Step 2: Installation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {"id": "vYQIpWr5EYRb"},
   "outputs": [],
   "source": [
    "%pip install \"neural-amp-modeler\"\n",
    "#pip install git+https://github.com/38github/neural-amp-modeler.git@lstm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {"id": "-6GUkLz3EayL"},
   "outputs": [],
   "source": [
    "from time import time\n",
    "from typing import Optional, Union\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pytorch_lightning as pl\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "from nam.data import Split, init_dataset\n",
    "from nam.train.lightning_module import LightningModule\n",
    "from nam.models.losses import esr"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {"id": "j5fN10s3GwVz"},
   "source": [
    "## Step 3: Settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {"id": "Y6gl6RoNJ_6I"},
   "outputs": [],
   "source": [
    "# Paths are built from DATA_DIR (defined in the file-check cell above) so they\n",
    "# do not depend on the notebook's current working directory.\n",
    "data_config = {\n",
    "    \"train\": {\n",
    "        \"x_path\": str(DATA_DIR / \"x_train.wav\"),\n",
    "        \"y_path\": str(DATA_DIR / \"y_train.wav\"),\n",
    "        \"ny\": 48000\n",
    "    },\n",
    "    \"validation\": {\n",
    "        \"x_path\": str(DATA_DIR / \"x_test.wav\"),\n",
    "        \"y_path\": str(DATA_DIR / \"y_test.wav\"),\n",
    "        \"ny\": None\n",
    "    },\n",
    "    \"common\": {\n",
    "        \"delay\": int(input(\"What is the latency (in samples) of your reamp? \"))\n",
    "    },\n",
    "}\n",
    "model_config = {\n",
    "    \"net\": {\n",
    "        \"name\": \"LSTM\",\n",
    "        \"config\": {\n",
    "            \"num_layers\": 2,\n",
    "            \"hidden_size\": 14,\n",
    "            \"train_burn_in\": 64, #!# Much larger originally\n",
    "            \"train_truncate\": 20480 #!# Much smaller originally\n",
    "        }\n",
    "    },\n",
    "    \"loss\": {\n",
    "        \"val_loss\": \"mse\",\n",
    "        \"mask_first\": 1, #!# 4096\n",
    "#        \"pre_emph_weight\": 1.0,\n",
    "#        \"pre_emph_coef\": 0.90\n",
    "    },\n",
    "    \"optimizer\": {\n",
    "        \"lr\": 0.005 #!# 0.004\n",
    "    },\n",
    "    \"lr_scheduler\": {\n",
    "        \"class\": \"ExponentialLR\",\n",
    "        \"kwargs\": {\n",
    "            \"gamma\": 0.997\n",
    "        }\n",
    "    }\n",
    "}\n",
    "learning_config = {\n",
    "    \"train_dataloader\": {\n",
    "        \"batch_size\": 8, #!# 16\n",
    "        \"shuffle\": True,\n",
    "        \"pin_memory\": True,\n",
    "        \"drop_last\": False, #!# True\n",
    "        \"num_workers\": 8 #!# 0\n",
    "    },\n",
    "    \"val_dataloader\": {},\n",
    "    \"trainer\": {\n",
    "        \"accelerator\": \"gpu\",\n",
    "        \"devices\": 1,\n",
    "        \"max_epochs\": 1000 #!# 500-650 is usually enough\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {"id": "pNga-MNTMQAa"},
   "source": [
    "## Step 4: Run!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {"id": "UAChcygdMTF4"},
   "outputs": [],
   "source": [
    "model = LightningModule.init_from_config(model_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {"id": "iZ1jCOh7i7Ct"},
   "outputs": [],
   "source": [
    "data_config[\"common\"][\"nx\"] = model.net.receptive_field"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {"id": "OV4gLukTMjdD"},
   "outputs": [],
   "source": [
    "dataset_train = init_dataset(data_config, Split.TRAIN)\n",
    "dataset_validation = init_dataset(data_config, Split.VALIDATION)\n",
    "train_dataloader = DataLoader(dataset_train, **learning_config[\"train_dataloader\"])\n",
    "val_dataloader = DataLoader(dataset_validation, **learning_config[\"val_dataloader\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {"id": "vyhMf0ZyM4kt"},
   "outputs": [],
   "source": [
    "# The first callback is the one exposed as trainer.checkpoint_callback; it\n",
    "# monitors val_loss. The second one writes a \"checkpoint_last_...\" file every epoch.\n",
    "trainer = pl.Trainer(\n",
    "    callbacks=[\n",
    "        pl.callbacks.model_checkpoint.ModelCheckpoint(\n",
    "            filename=\"{epoch:04d}_{step}_{ESR:.3e}_{MSE:.3e}\",\n",
    "            save_top_k=20,\n",
    "            monitor=\"val_loss\",\n",
    "            every_n_epochs=1,\n",
    "        ),\n",
    "        pl.callbacks.model_checkpoint.ModelCheckpoint(\n",
    "            filename=\"checkpoint_last_{epoch:04d}_{step}_{ESR:.3e}_{MSE:.3e}\", every_n_epochs=1\n",
    "        ),\n",
    "    ],\n",
    "    **learning_config[\"trainer\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {"id": "a8WLIx33M7c6"},
   "outputs": [],
   "source": [
    "#torch.set_float32_matmul_precision('medium') # highest, high, medium\n",
    "trainer.fit(model, train_dataloader, val_dataloader)\n",
    "# Monitor the progress in lightning_logs/version_0/checkpoints.\n",
    "#\n",
    "# Many models can get a good result (rule of thumb: look for ESR<0.01) in about 15\n",
    "# minutes of training, but if you're more patient, it'll probably keep getting better."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {"id": "JzGltwwJNAkI"},
   "outputs": [],
   "source": [
    "# Go to best checkpoint\n",
    "best_checkpoint = trainer.checkpoint_callback.best_model_path\n",
    "# An empty path means no best checkpoint was recorded; keep the in-memory model.\n",
    "if best_checkpoint != \"\":\n",
    "    model = LightningModule.load_from_checkpoint(\n",
    "        best_checkpoint,\n",
    "        **LightningModule.parse_config(model_config),\n",
    "    )\n",
    "model.cpu()\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {"id": "QvuJEYxJNGn7"},
   "source": [
    "## Step 5: Check"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {"id": "K0UeoIbaNMxF"},
   "outputs": [],
   "source": [
    "def plot(\n",
    "    model,\n",
    "    ds,\n",
    "    savefig=None,\n",
    "    show=True,\n",
    "    window_start: Optional[int] = None,\n",
    "    window_end: Optional[int] = None,\n",
    "    sample_rate: int = 48_000,\n",
    "):\n",
    "    \"\"\"\n",
    "    Run the model over a dataset and plot its prediction against the target.\n",
    "\n",
    "    The ESR between the full prediction and ``ds.y`` is shown in the title.\n",
    "\n",
    "    :param model: Model to evaluate; it is called as ``model(ds.x)``.\n",
    "    :param ds: Dataset providing the input as ``ds.x`` and the target as ``ds.y``.\n",
    "    :param savefig: If not None, the figure is saved to this path.\n",
    "    :param show: Whether to display the figure with ``plt.show()``.\n",
    "    :param window_start: Start of the plotting window, in samples.\n",
    "    :param window_end: End of the plotting window, in samples.\n",
    "    :param sample_rate: Samples per second of the audio. Only used to report\n",
    "        the duration of the input and the speed relative to real time.\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        tx = len(ds.x) / sample_rate\n",
    "        print(f\"Run (t={tx})\")\n",
    "        t0 = time()\n",
    "        output = model(ds.x).flatten().cpu().numpy()\n",
    "        t1 = time()\n",
    "        print(f\"Took {t1 - t0} ({tx / (t1 - t0):.2f}x)\")\n",
    "\n",
    "    plt.figure(figsize=(16, 5))\n",
    "    plt.plot(output[window_start:window_end], label=\"Prediction\")\n",
    "    plt.plot(ds.y[window_start:window_end], linestyle=\"--\", label=\"Target\")\n",
    "    plt.title(f\"ESR={esr(torch.Tensor(output), ds.y):.4f}\")\n",
    "    plt.legend()\n",
    "    if savefig is not None:\n",
    "        plt.savefig(savefig)\n",
    "    if show:\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {"id": "C_NsBdp5NQMC"},
   "outputs": [],
   "source": [
    "plot(\n",
    "    model,\n",
    "    dataset_validation,\n",
    "    window_start=100_000,  # Start of the plotting window, in samples\n",
    "    window_end=101_000,  # End of the plotting window, in samples\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {"id": "R__jJFwgNkAl"},
   "source": [
    "## Step 6: Export your model\n",
    "Now we'll use NAM's exporting utility to convert the model from its PyTorch representation to something that you can put into the plugin."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {"id": "yQDcgoi_NrsW"},
   "outputs": [],
   "source": [
    "#!# The folder exported_model is created in the root of your Google Drive\n",
    "#!# if it does not exist yet.\n",
    "#!#\n",
    "#!# Don't forget to open the model.nam with a text editor and at the end change\n",
    "#!# sample_rate value from null to the sample rate of the training files.\n",
    "#!# Usually 48000\n",
    "#!#\n",
    "EXPORT_DIR = Path(\"/content/drive/MyDrive/exported_model\")\n",
    "EXPORT_DIR.mkdir(parents=True, exist_ok=True)\n",
    "model.net.export(str(EXPORT_DIR))"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": [
    {
     "file_id": "https://github.com/sdatkinson/neural-amp-modeler/blob/main/bin/train/colab.ipynb",
     "timestamp": 1684685514258
    }
   ]
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "vscode": {
   "interpreter": {
    "hash": "920df60c69944ba95f8c12adb41fedfdc8090c370a20d39253c7705973dd37db"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}