# %% [markdown]
# ## Step 1: Upload your data
# Upload `x_train.wav`, `y_train.wav`, `x_test.wav` and `y_test.wav` to the root
# directory of your Google Drive, then run the two cells below to mount the drive
# and verify that all four files are visible.

# %%
from pathlib import Path

from google.colab import drive

# Single source of truth for where the Drive is mounted. The file check below and
# the training configuration both derive their paths from it, so they cannot
# disagree (previously one used absolute paths and the other cwd-relative ones).
DRIVE_MOUNT = Path("/content/drive")
DRIVE_DIR = DRIVE_MOUNT / "MyDrive"

drive.mount(str(DRIVE_MOUNT))

# %%
# I'm just gonna check that you were paying attention ;)
DATA_FILES = {
    stem: DRIVE_DIR / f"{stem}.wav"
    for stem in ("x_train", "y_train", "x_test", "y_test")
}
# Collect every missing file first so one run reports all of them at once.
missing_files = [str(path) for path in DATA_FILES.values() if not path.exists()]
if missing_files:
    raise RuntimeError(
        f"I didn't find all of your data files. Where is {', '.join(missing_files)}?"
    )

# %% [markdown]
# ## Step 2: Installation

# %%
# Alternative sources that were tried in this notebook:
#   "neural-amp-modeler"                                              (PyPI)
#   "git+https://github.com/38github/neural-amp-modeler.git@main"     (fork)
NAM_PACKAGE = "git+https://github.com/sdatkinson/neural-amp-modeler.git"

# Programmatic form of the `%pip install ...` magic: installs into the interpreter
# the kernel is actually running, and still streams pip's output into the cell.
get_ipython().run_line_magic("pip", f"install {NAM_PACKAGE}")

# %%
from time import time
from typing import Optional, Union

import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader

from nam.data import Split, init_dataset
from nam.train.lightning_module import LightningModule
from nam.models.losses import esr

# %% [markdown]
# ## Step 3: Settings
# Choose **one** architecture preset by setting `PRESET` in the cell below.
# All presets share the same loss, optimizer, scheduler and trainer settings;
# they differ only in network width, the first layer array's dilations and the
# training dataloader options:
#
# | Preset | Channels (layer array 1 / 2) |
# |---|---|
# | `NANO` | 4 / 2 |
# | `FEATHER` | 8 / 4 |
# | `REVERSED3 STANDARD` | 16 / 8 |

# %%
# Which preset to train. "REVERSED3 STANDARD" is the default because it was the
# last settings cell in the notebook, i.e. the one that took effect on "Run all".
PRESET = "REVERSED3 STANDARD"  # one of: "NANO", "FEATHER", "REVERSED3 STANDARD"

DILATIONS_SHORT = (128, 64, 32, 16, 8, 4, 2, 1, 24)
DILATIONS_LONG = (512, 256, 128, 64, 32, 16, 8, 4, 2, 1, 24)

# Only the values that differ between presets live here; everything shared is
# written once in `build_configs`.
PRESETS = {
    "NANO": {
        "channels": 4,
        "first_dilations": DILATIONS_SHORT,
        "train_dataloader": {"batch_size": 8, "drop_last": True, "num_workers": 8},
    },
    "FEATHER": {
        "channels": 8,
        "first_dilations": DILATIONS_SHORT,
        "train_dataloader": {"batch_size": 16, "drop_last": True, "num_workers": 8},
    },
    "REVERSED3 STANDARD": {
        "channels": 16,
        "first_dilations": DILATIONS_LONG,
        "train_dataloader": {"batch_size": 16, "drop_last": False, "num_workers": 0},
    },
}


def _layer_config(input_size: int, channels: int, head_size: int, dilations) -> dict:
    """Build one entry of the WaveNet ``layers_configs`` list.

    Kernel size, activation, gating and head bias are identical for every layer
    array in every preset, so they are fixed here.
    """
    return {
        "input_size": input_size,
        "condition_size": 1,
        "channels": channels,
        "head_size": head_size,
        "kernel_size": 3,
        "dilations": list(dilations),  # fresh list: configs never share mutable state
        "activation": "Hardtanh",
        "gated": False,
        "head_bias": True,
    }


def build_configs(preset: str, delay: int) -> tuple:
    """Return ``(data_config, model_config, learning_config)`` for a preset.

    Args:
        preset: Key of ``PRESETS`` ("NANO", "FEATHER" or "REVERSED3 STANDARD").
        delay: Latency of the reamp in samples; stored under
            ``data_config["common"]["delay"]``.

    Raises:
        ValueError: If ``preset`` is not a key of ``PRESETS``.
    """
    if preset not in PRESETS:
        raise ValueError(
            f"Unknown preset {preset!r}; choose one of {sorted(PRESETS)}"
        )
    spec = PRESETS[preset]
    channels = spec["channels"]

    data_config = {
        "train": {
            "x_path": str(DATA_FILES["x_train"]),
            "y_path": str(DATA_FILES["y_train"]),
            "ny": 8192,
        },
        "validation": {
            "x_path": str(DATA_FILES["x_test"]),
            "y_path": str(DATA_FILES["y_test"]),
            "ny": None,
        },
        "common": {"delay": delay},
    }

    model_config = {
        "net": {
            "name": "WaveNet",
            # This is a modified version of the "standard" model in easy mode /
            # the local GUI trainer.
            "config": {
                "layers_configs": [
                    # First layer array: full width, head half as wide.
                    _layer_config(1, channels, channels // 2, spec["first_dilations"]),
                    # Second layer array: half width, single-channel head.
                    _layer_config(channels, channels // 2, 1, DILATIONS_LONG),
                ],
                "head_scale": 0.08,  #!# 0.02
            },
        },
        "loss": {
            "val_loss": "esr",
            "pre_emph_mrstft_weight": 0.00002,
            "pre_emph_mrstft_coef": 0.85,
        },
        "optimizer": {"lr": 0.002},
        "lr_scheduler": {
            "class": "ExponentialLR",
            "kwargs": {"gamma": 0.9985},
        },
    }

    learning_config = {
        "train_dataloader": {
            "shuffle": True,
            "pin_memory": False,
            **spec["train_dataloader"],
        },
        "val_dataloader": {},
        "trainer": {"accelerator": "gpu", "devices": 1, "max_epochs": 5000},
    }
    return data_config, model_config, learning_config


# %%
# Asked exactly once, regardless of which preset is selected.
reamp_delay = int(input("What is the latency (in samples) of your reamp? "))
data_config, model_config, learning_config = build_configs(PRESET, reamp_delay)

# %% [markdown]
# ## Step 4: Run!

# %%
model = LightningModule.init_from_config(model_config)

# %%
# The dataset needs the network's receptive field, which is only known once the
# model has been built.
data_config["common"]["nx"] = model.net.receptive_field

# %%
dataset_train = init_dataset(data_config, Split.TRAIN)
dataset_validation = init_dataset(data_config, Split.VALIDATION)
train_dataloader = DataLoader(dataset_train, **learning_config["train_dataloader"])
val_dataloader = DataLoader(dataset_validation, **learning_config["val_dataloader"])

# %%
trainer = pl.Trainer(
    callbacks=[
        # Listed first on purpose: the next step reads
        # `trainer.checkpoint_callback.best_model_path` to find the checkpoint
        # with the best monitored "val_loss".
        pl.callbacks.model_checkpoint.ModelCheckpoint(
            filename="{epoch:04d}_{step}_{ESR:.3e}_{MSE:.3e}",
            save_top_k=20,
            monitor="val_loss",
            every_n_epochs=1,
        ),
        # Rolling "latest" checkpoint, written every epoch.
        pl.callbacks.model_checkpoint.ModelCheckpoint(
            filename="checkpoint_last_{epoch:04d}_{step}", every_n_epochs=1
        ),
    ],
    **learning_config["trainer"],
)

# %%
trainer.fit(model, train_dataloader, val_dataloader)
# Monitor the progress in lightning_logs/version_0/checkpoints.
#
# Many models can get a good result (rule of thumb: look for ESR<0.01) in about 15
# minutes of training, but if you're more patient, it'll probably keep getting better.
# %%
# Go to best checkpoint
best_checkpoint = trainer.checkpoint_callback.best_model_path
# An empty path means no best checkpoint was recorded; in that case keep the
# in-memory model as it is.
if best_checkpoint != "":
    model = LightningModule.load_from_checkpoint(
        best_checkpoint,
        **LightningModule.parse_config(model_config),
    )
model.cpu()
model.eval()

# %% [markdown]
# ## Step 5: Check
# Let's look at how well our model matches the real thing.

# %%
def plot(
    model,
    ds,
    savefig=None,
    show=True,
    window_start: Optional[int] = None,
    window_end: Optional[int] = None,
    sample_rate: float = 48_000,
):
    """Run ``model`` over ``ds.x`` and plot its prediction against ``ds.y``.

    Prints the audio duration and how fast the model ran relative to real time,
    then draws prediction and target for the requested window. The ESR shown in
    the title is computed over the whole signal, not just the plotted window.

    Args:
        model: Callable applied to ``ds.x``; its output is flattened, moved to
            the CPU and converted to a NumPy array.
        ds: Object exposing ``x`` (model input) and ``y`` (target) attributes,
            e.g. the validation dataset. Presumably 1-D tensors - TODO confirm.
        savefig: If not ``None``, passed to ``plt.savefig`` to write the figure.
        show: If true, display the figure with ``plt.show()``; otherwise the
            figure is closed so repeated calls do not accumulate open figures.
        window_start: Start of the plotting window, in samples (slice bound).
        window_end: End of the plotting window, in samples (slice bound).
        sample_rate: Samples per second, used only to convert ``len(ds.x)`` to
            seconds for the timing printout. Defaults to the previously
            hard-coded 48 kHz.
    """
    with torch.no_grad():
        tx = len(ds.x) / sample_rate
        print(f"Run (t={tx})")
        t0 = time()
        output = model(ds.x).flatten().cpu().numpy()
        t1 = time()
        elapsed = t1 - t0
        # Guard the real-time factor against a zero reading from a coarse clock.
        speed = tx / elapsed if elapsed > 0 else float("inf")
        print(f"Took {elapsed} ({speed:.2f}x)")

    fig = plt.figure(figsize=(16, 5))
    plt.plot(output[window_start:window_end], label="Prediction")
    plt.plot(ds.y[window_start:window_end], linestyle="--", label="Target")
    plt.title(f"ESR={esr(torch.Tensor(output), ds.y):.6f}")
    # The slices are plotted against their own index, so 0 is `window_start`.
    plt.xlabel("Sample (relative to window start)")
    plt.ylabel("Amplitude")
    plt.legend()
    if savefig is not None:
        plt.savefig(savefig)
    if show:
        plt.show()
    else:
        plt.close(fig)

# %%
plot(
    model,
    dataset_validation,
    window_start=100_000,  # Start of the plotting window, in samples
    window_end=101_000,  # End of the plotting window, in samples
)

# %% [markdown]
# ## Step 6: Export your model
# Now we'll use NAM's exporting utility to convert the model from its PyTorch
# representation to something that you can put into the plugin.
#
# The model is written to a folder called `exported_model` in the root of your
# Google Drive. The folder is created for you if it does not exist yet.

# %%
EXPORT_DIR = Path("drive/MyDrive/exported_model")
# Create the target folder up front; `exist_ok=True` makes the cell safe to re-run.
EXPORT_DIR.mkdir(parents=True, exist_ok=True)
model.net.export(EXPORT_DIR)

# %% [markdown]
# ## Step 7: Download your artifacts
# We're done!
# Your model has been saved as `model.nam` in the `exported_model` folder of your
# Google Drive. You can also find it in the file browser on the left panel ⬅ under
# `drive/MyDrive/exported_model` (you may need to hit the refresh button).
#
# Additionally, if you want to continue to train this model later you can download
# the lightning model artifacts from `lightning_logs`. If not, that's fine.
#
# # 🎸 **ENJOY!** 🎸