In [None]:
# Mount Google Drive at /content/drive so this notebook can read the training
# WAV files from MyDrive and later write the exported model back there.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
from pathlib import Path
# Sanity check: make sure every data file is where this notebook expects it
# before spending time on installation and training.
#
# CHECK THE PATHS! If your files live somewhere else on your Drive, update
# DATA_DIR here *and* the "x_path"/"y_path" entries in data_config below.
DATA_DIR = Path("/content/drive/MyDrive")
_data_files = [
    DATA_DIR / fname
    for fname in ("x_train.wav", "y_train.wav", "x_test.wav", "y_test.wav")
]
# Report every missing file at once instead of stopping at the first one, so a
# single fix-and-rerun is enough.
_missing_files = [str(p) for p in _data_files if not p.exists()]
if _missing_files:
    raise RuntimeError(
        "I didn't find all of your data files. Missing: " + ", ".join(_missing_files)
    )

## Step 2: Installation

In [None]:
# Install the neural-amp-modeler package (imported below as `nam`).
# NOTE(review): the version is unpinned, so a future release could change the
# config keys / APIs this notebook relies on; consider pinning a known-good
# version. `%pip` would also target the running kernel's environment more
# reliably than `!pip`.
!pip install "neural-amp-modeler"
# Alternative (disabled): install from the `lstm` branch of a fork instead of PyPI.
#pip install git+https://github.com/38github/neural-amp-modeler.git@lstm

In [None]:
from time import time
from typing import Optional, Union

import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader

from nam.data import Split, init_dataset
from nam.train.lightning_module import LightningModule
from nam.models.losses import esr

## Step 3: Settings

In [None]:
# Root of the mounted Drive. Absolute paths match the existence check above and
# keep working even if the notebook's working directory is not /content.
_DRIVE_DIR = "/content/drive/MyDrive"

# Latency of the reamp chain, in samples (asked interactively).
_reamp_delay = int(input("What is the latency (in samples) of your reamp? "))

# Dataset settings, consumed by nam.data.init_dataset. "x" is the model input
# and "y" the target signal the model should reproduce.
data_config = {
    "train": {
        "x_path": f"{_DRIVE_DIR}/x_train.wav",
        "y_path": f"{_DRIVE_DIR}/y_train.wav",
        "ny": 48000,
    },
    "validation": {
        "x_path": f"{_DRIVE_DIR}/x_test.wav",
        "y_path": f"{_DRIVE_DIR}/y_test.wav",
        "ny": None,
    },
    "common": {
        "delay": _reamp_delay,
    },
}

# Network / loss / optimizer / LR-schedule settings for LightningModule.
model_config = {
    "net": {
        "name": "LSTM",
        "config": {
            "num_layers": 2,
            "hidden_size": 14,
            "train_burn_in": 64,  # the original template used a much larger value
            "train_truncate": 20480,  # the original template used a much smaller value
        },
    },
    "loss": {
        "val_loss": "mse",
        "mask_first": 1,  # original template: 4096
        # Pre-emphasis loss settings, currently disabled:
        # "pre_emph_weight": 1.0,
        # "pre_emph_coef": 0.90,
    },
    "optimizer": {
        "lr": 0.005,  # original template: 0.004
    },
    "lr_scheduler": {
        "class": "ExponentialLR",
        "kwargs": {
            "gamma": 0.997,
        },
    },
}

# DataLoader and Trainer settings.
learning_config = {
    "train_dataloader": {
        "batch_size": 8,  # original template: 16
        "shuffle": True,
        "pin_memory": True,
        "drop_last": False,  # original template: True
        # NOTE(review): Colab runtimes often have few CPU cores; PyTorch may warn
        # when this exceeds its suggested worker count.
        "num_workers": 8,  # original template: 0
    },
    "val_dataloader": {},
    "trainer": {
        "accelerator": "gpu",
        "devices": 1,
        "max_epochs": 1000,  # 500-650 is usually enough
    },
}

## Step 4: Run!

In [None]:
# Build the NAM LightningModule from model_config (its "net", "loss",
# "optimizer" and "lr_scheduler" sections).
model = LightningModule.init_from_config(model_config)

In [None]:
# Share the network's receptive field with both dataset splits: it is stored as
# `nx` in the "common" dataset settings used by init_dataset below.
data_config["common"].update(nx=model.net.receptive_field)

In [None]:
# One dataset per split, both built from data_config (train first, then validation).
dataset_train, dataset_validation = (
    init_dataset(data_config, split) for split in (Split.TRAIN, Split.VALIDATION)
)
# Wrap each dataset in a DataLoader using its own options from learning_config.
train_dataloader, val_dataloader = (
    DataLoader(dataset, **learning_config[loader_key])
    for dataset, loader_key in (
        (dataset_train, "train_dataloader"),
        (dataset_validation, "val_dataloader"),
    )
)

In [None]:
# Two checkpoint callbacks. Order matters: Lightning's
# `trainer.checkpoint_callback` (used later to locate the best model) returns
# the first ModelCheckpoint in the list, i.e. the val_loss-monitored one.
_Checkpoint = pl.callbacks.model_checkpoint.ModelCheckpoint

# Keeps the 20 best checkpoints ranked by validation loss.
best_checkpoints_cb = _Checkpoint(
    filename="{epoch:04d}_{step}_{ESR:.3e}_{MSE:.3e}",
    monitor="val_loss",
    save_top_k=20,
    every_n_epochs=1,
)
# No metric monitored: a rolling checkpoint of the latest epoch.
latest_checkpoint_cb = _Checkpoint(
    filename="checkpoint_last_{epoch:04d}_{step}_{ESR:.3e}_{MSE:.3e}",
    every_n_epochs=1,
)

trainer = pl.Trainer(
    callbacks=[best_checkpoints_cb, latest_checkpoint_cb],
    **learning_config["trainer"],
)

In [None]:
# Optional (disabled): lower PyTorch's float32 matmul precision, which can speed
# up training on supported GPUs.
#torch.set_float32_matmul_precision('medium') # highest, high, medium
# Train the model, validating against val_dataloader; checkpoints are written by
# the two ModelCheckpoint callbacks configured on the trainer.
trainer.fit(model, train_dataloader, val_dataloader)
# Monitor the progress in lightning_logs/version_0/checkpoints.
#
# Many models can get a good result (rule of thumb: look for ESR<0.01) in about 15
# minutes of training, but if you're more patient, it'll probably keep getting better.

In [None]:
# Go to best checkpoint
# Reload the best checkpoint (best val_loss, as tracked by the first
# ModelCheckpoint callback) so evaluation and export use it rather than the
# final-epoch weights. parse_config presumably supplies the constructor
# arguments the checkpoint needs.
best_checkpoint = trainer.checkpoint_callback.best_model_path
if best_checkpoint:  # empty string (or None) means nothing was saved
    model = LightningModule.load_from_checkpoint(
        best_checkpoint,
        **LightningModule.parse_config(model_config),
    )
else:
    # Don't fall back silently: tell the user which weights are being used.
    print("No best checkpoint found; using the current in-memory model weights.")
# Move to CPU and switch to eval mode for plotting and export. The trailing
# semicolon keeps Jupyter from dumping the whole module repr.
model.cpu()
model.eval();

## Step 5: Check

In [None]:
def plot(
    model,
    ds,
    savefig=None,
    show=True,
    window_start: Optional[int] = None,
    window_end: Optional[int] = None,
    sample_rate: int = 48_000,
):
    """Run ``model`` over a whole dataset and plot its prediction against the target.

    Prints the audio duration and how many times faster than real time the
    model processed it. Then plots a window of the prediction vs. ``ds.y``,
    with the full-signal ESR in the title.

    Args:
        model: Network called as ``model(ds.x)``; should be on the same device
            as ``ds.x``.
        ds: Dataset exposing ``x`` (model input) and ``y`` (target); presumably
            1-D CPU tensors.
        savefig: Optional path; if given, the figure is saved there.
        show: Whether to call ``plt.show()``.
        window_start: First sample of the plotted window (``None`` = start).
        window_end: End (exclusive) of the plotted window (``None`` = end).
        sample_rate: Sample rate of ``ds.x`` in Hz. It is used only for the
            duration / real-time report and defaults to the previously
            hard-coded 48 kHz.
    """
    from time import perf_counter  # monotonic, high-resolution timer

    with torch.no_grad():
        duration_s = len(ds.x) / sample_rate
        print(f"Run (t={duration_s})")
        t0 = perf_counter()
        output = model(ds.x).flatten().cpu().numpy()
        elapsed = perf_counter() - t0
        # Avoid ZeroDivisionError when the run is too short to measure.
        speed = f"{duration_s / elapsed:.2f}" if elapsed > 0 else "???"
        print(f"Took {elapsed} ({speed}x)")

    fig, ax = plt.subplots(figsize=(16, 5))
    ax.plot(output[window_start:window_end], label="Prediction")
    ax.plot(ds.y[window_start:window_end], linestyle="--", label="Target")
    # ESR is computed over the full signal, not just the plotted window.
    ax.set_title(f"ESR={esr(torch.Tensor(output), ds.y):.4f}")
    ax.set_xlabel("Sample (relative to window_start)")
    ax.set_ylabel("Amplitude")
    ax.legend()
    if savefig is not None:
        fig.savefig(savefig)
    if show:
        plt.show()

In [None]:
# Zoom in on a 1,000-sample stretch of the validation audio (bounds in samples)
# to compare prediction and target waveforms up close.
PLOT_WINDOW_START, PLOT_WINDOW_END = 100_000, 101_000
plot(
    model,
    dataset_validation,
    window_end=PLOT_WINDOW_END,
    window_start=PLOT_WINDOW_START,
)

## Step 6: Export your model
Now we'll use NAM's export utility to convert the model from its PyTorch representation into a `.nam` file that you can load into the plugin.

In [None]:
# Export the trained network to the exported_model folder in the root of your
# Google Drive. The folder is created if it doesn't exist yet, so no manual
# setup is needed.
#
# Don't forget to open the exported model.nam with a text editor and, at the
# end, change the sample_rate value from null to the sample rate of the
# training files (usually 48000).
export_dir = "/content/drive/MyDrive/exported_model"
Path(export_dir).mkdir(parents=True, exist_ok=True)
model.net.export(export_dir)