import%20marimo%0A%0A__generated_with%20%3D%20%220.13.7%22%0Aapp%20%3D%20marimo.App(width%3D%22full%22)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20import%20marimo%20as%20mo%0A%20%20%20%20return%20(mo%2C)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(%0A%20%20%20%20%20%20%20%20r%22%22%22%0A%20%20%20%20%23%20CIFAR%20Demonstration%0A%0A%20%20%20%20This%20notebook%20demonstrates%20how%20to%20use%20the%20%60hierarchicalsoftmax%60%20module%20to%20train%20a%20neural%20network%20on%20the%20%5BCIFAR%5D(https%3A%2F%2Fwww.cs.toronto.edu%2F~kriz%2Fcifar.html)%20dataset.%0A%20%20%20%20%22%22%22%0A%20%20%20%20)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22First%2C%20choose%20the%20hyperparameters.%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20cifar_radio%20%3D%20mo.ui.radio(options%3D%5B%2210%22%2C%22100%22%5D%2C%20value%3Dmo.cli_args().get(%22cifar%22)%20or%20%22100%22%2C%20label%3D%22CIFAR%20Dataset%22)%0A%20%20%20%20batch_size_input%20%3D%20mo.ui.number(value%3Dmo.cli_args().get(%22batch%22)%20or%2032%2C%20label%3D%22Batch%20Size%22)%0A%20%20%20%20epochs_input%20%3D%20mo.ui.number(value%3Dmo.cli_args().get(%22batch%22)%20or%2010%2C%20label%3D%22Epochs%22)%0A%20%20%20%20mo.vstack(%5Bcifar_radio%2C%20epochs_input%2C%20batch_size_input%5D)%0A%20%20%20%20return%20batch_size_input%2C%20cifar_radio%2C%20epochs_input%0A%0A%0A%40app.cell%0Adef%20_(batch_size_input%2C%20cifar_radio%2C%20epochs_input)%3A%0A%20%20%20%20from%20torchvision%20import%20datasets%2C%20transforms%0A%20%20%20%20from%20torch.utils.data%20import%20DataLoader%0A%0A%20%20%20%20assert%20cifar_radio.value%20in%20%5B%2210%22%2C%22100%22%5D%0A%20%20%20%20batch_size%20%3D%20batch_size_input.value%0A%20%20%20%20epochs%20%3D%20epochs_input.value%0A%20%20%20%20cifar_dataset%20%3D%20datasets.CIFAR10%20if%20cifar_radio.value%20%3D%3D%20%2210%22%20else%20datasets.CIFAR100%0A%0A%20%20%20%20%23%20Use%20the%20same%20data%20augmentation%20strategies%20as%20in%20https%3A%2F%2Farxiv.org%2Fpdf%2F1605.07146v4%0A%20%20%20%20transform%20%3D%20transforms.Compose(%5B%0A%20%20%20%20%20%20%20%20transforms.RandomCrop(32%2C%20padding%3D4%2C%20padding_mode%3D%22reflect%22)%2C%0A%20%20%20%20%20%20%20%20transforms.RandomHorizontalFlip()%2C%0A%20%20%20%20%20%20%20%20transforms.ToTensor()%2C%0A%20%20%20%20%5D)%0A%0A%20%20%20%20train_data%20%3D%20cifar_dataset(root%3D%22.%22%2C%20train%3DTrue%2C%20download%3DTrue%2C%20transform%3Dtransform)%0A%20%20%20%20test_data%20%3D%20cifar_dataset(root%3D%22.%22%2C%20train%3DFalse%2C%20download%3DTrue%2C%20transform%3Dtransforms.ToTensor())%0A%0A%20%20%20%20train_loader%20%3D%20DataLoader(train_data%2C%20batch_size%3Dbatch_size%2C%20shuffle%3DTrue)%0A%20%20%20%20test_loader%20%3D%20DataLoader(test_data%2C%20batch_size%3Dbatch_size%2C%20shuffle%3DFalse)%0A%20%20%20%20return%20(%0A%20%20%20%20%20%20%20%20DataLoader%2C%0A%20%20%20%20%20%20%20%20batch_size%2C%0A%20%20%20%20%20%20%20%20epochs%2C%0A%20%20%20%20%20%20%20%20test_data%2C%0A%20%20%20%20%20%20%20%20test_loader%2C%0A%20%20%20%20%20%20%20%20train_data%2C%0A%20%20%20%20%20%20%20%20train_loader%2C%0A%20%20%20%20)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%23%23%23%20Plot%20the%20first%2010%20images%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(train_data)%3A%0A%20%20%20%20import%20plotly.graph_objects%20as%20go%0A%20%20%20%20from%20plotly.subplots%20import%20make_subplots%0A%0A%20%20%20%20num_images%20%3D%2010%0A%0A%20%20%20%20%23%20Create%20a%20row%20of%20subplots%0A%20%20%20%20cifar_fig%20%3D%20make_subplots(%0A%20%20%20%20%20%20%20%20rows%3D1%2C%20cols%3Dnum_images%2C%20%0A%20%20%20%20%20%20%20%20subplot_titles%3D%5Btrain_data.classes%5Btrain_data%5Bi%5D%5B1%5D%5D%20for%20i%20in%20range(num_images)%5D%2C%20%0A%20%20%20%20%20%20%20%20horizontal_spacing%3D0%2C%0A%20%20%20%20)%0A%0A%20%20%20%20for%20i%20in%20range(num_images)%3A%0A%20%20%20%20%20%20%20%20img%2C%20label%20%3D%20train_data%5Bi%5D%0A%20%20%20%20%20%20%20%20img%20%3D%20img.permute(1%2C%202%2C%200).numpy()%20%20%23%20(C%2C%20H%2C%20W)%20-%3E%20(H%2C%20W%2C%20C)%20and%20convert%20to%20numpy%0A%0A%20%20%20%20%20%20%20%20cifar_fig.add_trace(%0A%20%20%20%20%20%20%20%20%20%20%20%20go.Image(z%3D(img%20*%20255).astype('uint8'))%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20row%3D1%2C%20col%3Di%2B1%0A%20%20%20%20%20%20%20%20)%0A%0A%20%20%20%20%23%20Update%20layout%3A%20remove%20axes%20and%20tighten%20spacing%0A%20%20%20%20thumbnail_size%20%3D%20105%0A%20%20%20%20cifar_fig.update_layout(%0A%20%20%20%20%20%20%20%20height%3Dthumbnail_size%2C%20%20%23%20adjust%20height%20as%20needed%0A%20%20%20%20%20%20%20%20width%3Dthumbnail_size%20*%20num_images%2C%20%20%23%20150px%20per%20image%0A%20%20%20%20%20%20%20%20showlegend%3DFalse%2C%0A%20%20%20%20%20%20%20%20margin%3Ddict(l%3D0%2C%20r%3D0%2C%20t%3D30%2C%20b%3D0)%0A%20%20%20%20)%0A%0A%20%20%20%20%23%20Hide%20axes%0A%20%20%20%20for%20i%20in%20range(1%2C%20num_images%20%2B%201)%3A%0A%20%20%20%20%20%20%20%20cifar_fig.update_xaxes(visible%3DFalse%2C%20row%3D1%2C%20col%3Di)%0A%20%20%20%20%20%20%20%20cifar_fig.update_yaxes(visible%3DFalse%2C%20row%3D1%2C%20col%3Di)%0A%0A%20%20%20%20cifar_fig%0A%20%20%20%20return%20(go%2C)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(%0A%20%20%20%20%20%20%20%20r%22%22%22%0A%20%20%20%20%23%23%20Non-hierarchical%20model%0A%0A%20%20%20%20First%20we%20create%20a%20basic%20non-hierarchical%20model%20as%20a%20baseline%0A%20%20%20%20%22%22%22%0A%20%20%20%20)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(train_data)%3A%0A%20%20%20%20import%20torch%0A%20%20%20%20from%20torch%20import%20nn%0A%20%20%20%20from%20torchmetrics%20import%20Accuracy%0A%20%20%20%20import%20lightning%20as%20L%0A%0A%20%20%20%20import%20torch%0A%20%20%20%20import%20torch.nn%20as%20nn%0A%20%20%20%20import%20torch.nn.functional%20as%20F%0A%0A%20%20%20%20class%20BasicBlock(nn.Module)%3A%0A%20%20%20%20%20%20%20%20def%20__init__(self%2C%20in_planes%2C%20out_planes%2C%20stride)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20super().__init__()%0A%20%20%20%20%20%20%20%20%20%20%20%20self.bn1%20%3D%20nn.BatchNorm2d(in_planes)%0A%20%20%20%20%20%20%20%20%20%20%20%20self.conv1%20%3D%20nn.Conv2d(in_planes%2C%20out_planes%2C%20kernel_size%3D3%2C%20stride%3Dstride%2C%20padding%3D1%2C%20bias%3DFalse)%0A%20%20%20%20%20%20%20%20%20%20%20%20self.bn2%20%3D%20nn.BatchNorm2d(out_planes)%0A%20%20%20%20%20%20%20%20%20%20%20%20self.conv2%20%3D%20nn.Conv2d(out_planes%2C%20out_planes%2C%20kernel_size%3D3%2C%20stride%3D1%2C%20padding%3D1%2C%20bias%3DFalse)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20self.shortcut%20%3D%20nn.Sequential()%0A%20%20%20%20%20%20%20%20%20%20%20%20if%20stride%20!%3D%201%20or%20in_planes%20!%3D%20out_planes%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20self.shortcut%20%3D%20nn.Conv2d(in_planes%2C%20out_planes%2C%20kernel_size%3D1%2C%20stride%3Dstride%2C%20bias%3DFalse)%0A%0A%20%20%20%20%20%20%20%20def%20forward(self%2C%20x)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20out%20%3D%20self.conv1(F.relu(self.bn1(x)))%0A%20%20%20%20%20%20%20%20%20%20%20%20out%20%3D%20self.conv2(F.relu(self.bn2(out)))%0A%20%20%20%20%20%20%20%20%20%20%20%20out%20%2B%3D%20self.shortcut(x)%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20out%0A%0A%0A%20%20%20%20class%20WideResNetBody(nn.Module)%3A%0A%20%20%20%20%20%20%20%20def%20__init__(self%2C%20depth%3D16%2C%20width_factor%3D8)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20super().__init__()%0A%20%20%20%20%20%20%20%20%20%20%20%20assert%20(depth%20-%204)%20%25%206%20%3D%3D%200%2C%20%22Depth%20should%20be%206n%2B4%22%0A%20%20%20%20%20%20%20%20%20%20%20%20n%20%3D%20(depth%20-%204)%20%2F%2F%206%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20k%20%3D%20width_factor%0A%20%20%20%20%20%20%20%20%20%20%20%20self.in_planes%20%3D%2016%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Initial%20conv%0A%20%20%20%20%20%20%20%20%20%20%20%20self.conv1%20%3D%20nn.Conv2d(3%2C%2016%2C%20kernel_size%3D3%2C%20stride%3D1%2C%20padding%3D1%2C%20bias%3DFalse)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%203%20groups%0A%20%20%20%20%20%20%20%20%20%20%20%20self.layer1%20%3D%20self._make_layer(16*k%2C%20n%2C%20stride%3D1)%0A%20%20%20%20%20%20%20%20%20%20%20%20self.layer2%20%3D%20self._make_layer(32*k%2C%20n%2C%20stride%3D2)%0A%20%20%20%20%20%20%20%20%20%20%20%20self.layer3%20%3D%20self._make_layer(64*k%2C%20n%2C%20stride%3D2)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20self.bn%20%3D%20nn.BatchNorm2d(64*k)%0A%0A%20%20%20%20%20%20%20%20def%20_make_layer(self%2C%20out_planes%2C%20blocks%2C%20stride)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20strides%20%3D%20%5Bstride%5D%20%2B%20%5B1%5D*(blocks-1)%0A%20%20%20%20%20%20%20%20%20%20%20%20layers%20%3D%20%5B%5D%0A%20%20%20%20%20%20%20%20%20%20%20%20for%20s%20in%20strides%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20layers.append(BasicBlock(self.in_planes%2C%20out_planes%2C%20s))%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20self.in_planes%20%3D%20out_planes%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20nn.Sequential(*layers)%0A%0A%20%20%20%20%20%20%20%20def%20forward(self%2C%20x)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20out%20%3D%20self.conv1(x)%0A%20%20%20%20%20%20%20%20%20%20%20%20out%20%3D%20self.layer1(out)%0A%20%20%20%20%20%20%20%20%20%20%20%20out%20%3D%20self.layer2(out)%0A%20%20%20%20%20%20%20%20%20%20%20%20out%20%3D%20self.layer3(out)%0A%20%20%20%20%20%20%20%20%20%20%20%20out%20%3D%20F.relu(self.bn(out))%0A%20%20%20%20%20%20%20%20%20%20%20%20out%20%3D%20F.avg_pool2d(out%2C%208)%0A%20%20%20%20%20%20%20%20%20%20%20%20out%20%3D%20out.view(out.size(0)%2C%20-1)%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20out%0A%0A%0A%20%20%20%20class%20BasicImageClassifier(L.LightningModule)%3A%0A%20%20%20%20%20%20%20%20def%20__init__(self)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20super().__init__()%0A%20%20%20%20%20%20%20%20%20%20%20%20self.model%20%3D%20nn.Sequential(%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20WideResNetBody()%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20nn.LazyLinear(out_features%3Dlen(train_data.classes))%0A%20%20%20%20%20%20%20%20%20%20%20%20)%0A%20%20%20%20%20%20%20%20%20%20%20%20self.loss_fn%20%3D%20nn.CrossEntropyLoss()%0A%20%20%20%20%20%20%20%20%20%20%20%20self.metrics%20%3D%20%5B%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20Accuracy(task%3D%22multiclass%22%2C%20num_classes%3Dlen(train_data.classes))%0A%20%20%20%20%20%20%20%20%20%20%20%20%5D%0A%0A%20%20%20%20%20%20%20%20def%20forward(self%2C%20x)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20self.model(x)%0A%0A%20%20%20%20%20%20%20%20def%20training_step(self%2C%20batch%2C%20batch_idx)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20x%2C%20y%20%3D%20batch%0A%20%20%20%20%20%20%20%20%20%20%20%20logits%20%3D%20self(x)%0A%20%20%20%20%20%20%20%20%20%20%20%20loss%20%3D%20self.loss_fn(logits%2C%20y)%0A%20%20%20%20%20%20%20%20%20%20%20%20self.log('train_loss'%2C%20loss%2C%20prog_bar%3DTrue)%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20loss%0A%0A%20%20%20%20%20%20%20%20def%20validation_step(self%2C%20batch%2C%20batch_idx)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20x%2C%20y%20%3D%20batch%0A%20%20%20%20%20%20%20%20%20%20%20%20logits%20%3D%20self(x)%0A%20%20%20%20%20%20%20%20%20%20%20%20loss%20%3D%20self.loss_fn(logits%2C%20y)%0A%20%20%20%20%20%20%20%20%20%20%20%20self.log('val_loss'%2C%20loss%2C%20prog_bar%3DTrue)%0A%20%20%20%20%20%20%20%20%20%20%20%20for%20metric%20in%20self.metrics%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20metric%20%3D%20metric.to(logits.device)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20result%20%3D%20metric(logits%2C%20y)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20if%20isinstance(result%2C%20dict)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20for%20name%2C%20value%20in%20result.items()%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20self.log(f%22val_%7Bname%7D%22%2C%20value%2C%20on_step%3DFalse%2C%20on_epoch%3DTrue%2C%20prog_bar%3DTrue)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20else%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20self.log(f%22val_%7Bmetric.__class__.__name__%7D%22%2C%20result%2C%20on_step%3DFalse%2C%20on_epoch%3DTrue%2C%20prog_bar%3DTrue)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20loss%0A%0A%20%20%20%20%20%20%20%20def%20configure_optimizers(self)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20torch.optim.Adam(self.parameters()%2C%20lr%3D1e-3)%0A%0A%20%20%20%20basic_model%20%3D%20BasicImageClassifier()%0A%20%20%20%20basic_model%0A%20%20%20%20return%20BasicImageClassifier%2C%20L%2C%20basic_model%2C%20nn%2C%20torch%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%23%20Train%20the%20basic%20model%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(L%2C%20basic_model%2C%20epochs%2C%20test_loader%2C%20train_loader)%3A%0A%20%20%20%20from%20lightning.pytorch.loggers%20import%20CSVLogger%0A%0A%20%20%20%20basic_logger%20%3D%20CSVLogger(save_dir%3D%22lightning_logs%22%2C%20name%3D%22basic_model%22)%0A%20%20%20%20basic_trainer%20%3D%20L.Trainer(max_epochs%3Depochs%2C%20accelerator%3D%22auto%22%2C%20enable_checkpointing%3DFalse%2C%20logger%3Dbasic_logger)%0A%20%20%20%20basic_trainer.fit(basic_model%2C%20train_dataloaders%3Dtrain_loader%2C%20val_dataloaders%3Dtest_loader)%0A%20%20%20%20return%20CSVLogger%2C%20basic_logger%0A%0A%0A%40app.cell%0Adef%20_(basic_logger%2C%20go)%3A%0A%20%20%20%20import%20pandas%20as%20pd%0A%20%20%20%20from%20pathlib%20import%20Path%0A%0A%20%20%20%20basic_metrics_df%20%3D%20pd.read_csv(Path(basic_logger.log_dir)%20%2F%20%22metrics.csv%22)%0A%20%20%20%20basic_metrics_df%20%3D%20basic_metrics_df.dropna(subset%3D%5B%22val_MulticlassAccuracy%22%5D)%0A%20%20%20%20basic_fig%20%3D%20go.Figure()%0A%20%20%20%20basic_fig.add_trace(go.Scatter(x%3Dbasic_metrics_df%5B%22epoch%22%5D%2C%20y%3Dbasic_metrics_df%5B%22val_MulticlassAccuracy%22%5D%2C%20mode%3D'lines'%2C%20name%3D'class'))%0A%20%20%20%20basic_fig.update_layout(%0A%20%20%20%20%20%20%20%20xaxis_title%3D%22Epochs%22%2C%0A%20%20%20%20%20%20%20%20yaxis_title%3D%22Accuracy%22%2C%0A%20%20%20%20)%0A%20%20%20%20basic_fig.show()%0A%0A%20%20%20%20return%20Path%2C%20pd%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(%0A%20%20%20%20%20%20%20%20%22%22%22%0A%20%20%20%20%23%23%20Hierarchical%20model%0A%0A%20%20%20%20Let's%20now%20create%20a%20hierarchical%20model.%0A%20%20%20%20First%20we%20need%20to%20create%20a%20tree%20structure%20for%20the%20CIFAR%20dataset.%0A%20%20%20%20%22%22%22%0A%20%20%20%20)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(mo%2C%20train_data)%3A%0A%20%20%20%20from%20hierarchicalsoftmax%20import%20(%0A%20%20%20%20%20%20%20%20SoftmaxNode%2C%0A%20%20%20%20%20%20%20%20HierarchicalSoftmaxLazyLinear%2C%0A%20%20%20%20%20%20%20%20HierarchicalSoftmaxLoss%2C%0A%20%20%20%20)%0A%20%20%20%20from%20hierarchicalsoftmax.metrics%20import%20RankAccuracyTorchMetric%0A%0A%20%20%20%20if%20len(train_data.classes)%20%3D%3D%2010%3A%0A%20%20%20%20%20%20%20%20%23%20CIFAR-10%0A%20%20%20%20%20%20%20%20superclasses%20%3D%20%7B%0A%20%20%20%20%20%20%20%20%20%20%20%20%22animals%22%3A%20%5B%22bird%22%2C%20%22cat%22%2C%20%22deer%22%2C%20%22dog%22%2C%20%22frog%22%2C%20%22horse%22%5D%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%22vehicles%22%3A%20%5B%22airplane%22%2C%20%22automobile%22%2C%20%22ship%22%2C%20%22truck%22%5D%2C%0A%20%20%20%20%20%20%20%20%7D%0A%20%20%20%20else%3A%0A%20%20%20%20%20%20%20%20%23%20CIFAR-100%0A%20%20%20%20%20%20%20%20superclasses%20%3D%20%7B%0A%20%20%20%20%20%20%20%20%20%20%20%20%22aquatic%20mammals%22%3A%20%5B%22beaver%22%2C%20%22dolphin%22%2C%20%22otter%22%2C%20%22seal%22%2C%20%22whale%22%5D%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%22fish%22%3A%20%5B%22aquarium_fish%22%2C%20%22flatfish%22%2C%20%22ray%22%2C%20%22shark%22%2C%20%22trout%22%5D%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%22flowers%22%3A%20%5B%22orchid%22%2C%20%22poppy%22%2C%20%22rose%22%2C%20%22sunflower%22%2C%20%22tulip%22%5D%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%22food%20containers%22%3A%20%5B%22bottle%22%2C%20%22bowl%22%2C%20%22can%22%2C%20%22cup%22%2C%20%22plate%22%5D%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%22fruit%20and%20vegetables%22%3A%20%5B%22apple%22%2C%20%22mushroom%22%2C%20%22orange%22%2C%20%22pear%22%2C%20%22sweet_pepper%22%5D%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%22household%20electrical%20devices%22%3A%20%5B%22clock%22%2C%20%22keyboard%22%2C%20%22lamp%22%2C%20%22telephone%22%2C%20%22television%22%5D%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%22household%20furniture%22%3A%20%5B%22bed%22%2C%20%22chair%22%2C%20%22couch%22%2C%20%22table%22%2C%20%22wardrobe%22%5D%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%22insects%22%3A%20%5B%22bee%22%2C%20%22beetle%22%2C%20%22butterfly%22%2C%20%22caterpillar%22%2C%20%22cockroach%22%5D%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%22large%20carnivores%22%3A%20%5B%22bear%22%2C%20%22leopard%22%2C%20%22lion%22%2C%20%22tiger%22%2C%20%22wolf%22%5D%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%22large%20man-made%20outdoor%20things%22%3A%20%5B%22bridge%22%2C%20%22castle%22%2C%20%22house%22%2C%20%22road%22%2C%20%22skyscraper%22%5D%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%22large%20natural%20outdoor%20scenes%22%3A%20%5B%22cloud%22%2C%20%22forest%22%2C%20%22mountain%22%2C%20%22plain%22%2C%20%22sea%22%5D%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%22large%20omnivores%20and%20herbivores%22%3A%20%5B%22camel%22%2C%20%22cattle%22%2C%20%22chimpanzee%22%2C%20%22elephant%22%2C%20%22kangaroo%22%5D%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%22medium-sized%20mammals%22%3A%20%5B%22fox%22%2C%20%22porcupine%22%2C%20%22possum%22%2C%20%22raccoon%22%2C%20%22skunk%22%5D%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%22non-insect%20invertebrates%22%3A%20%5B%22crab%22%2C%20%22lobster%22%2C%20%22snail%22%2C%20%22spider%22%2C%20%22worm%22%5D%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%22people%22%3A%20%5B%22baby%22%2C%20%22boy%22%2C%20%22girl%22%2C%20%22man%22%2C%20%22woman%22%5D%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%22reptiles%22%3A%20%5B%22crocodile%22%2C%20%22dinosaur%22%2C%20%22lizard%22%2C%20%22snake%22%2C%20%22turtle%22%5D%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%22small%20mammals%22%3A%20%5B%22hamster%22%2C%20%22mouse%22%2C%20%22rabbit%22%2C%20%22shrew%22%2C%20%22squirrel%22%5D%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%22trees%22%3A%20%5B%22maple_tree%22%2C%20%22oak_tree%22%2C%20%22palm_tree%22%2C%20%22pine_tree%22%2C%20%22willow_tree%22%5D%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%22vehicles%201%22%3A%20%5B%22bicycle%22%2C%20%22bus%22%2C%20%22motorcycle%22%2C%20%22pickup_truck%22%2C%20%22train%22%5D%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%22vehicles%202%22%3A%20%5B%22lawn_mower%22%2C%20%22rocket%22%2C%20%22streetcar%22%2C%20%22tank%22%2C%20%22tractor%22%5D%2C%0A%20%20%20%20%20%20%20%20%7D%0A%0A%0A%20%20%20%20root%20%3D%20SoftmaxNode(%22root%22)%0A%20%20%20%20for%20superclass%2C%20classes%20in%20superclasses.items()%3A%0A%20%20%20%20%20%20%20%20superclass_node%20%3D%20SoftmaxNode(superclass%2C%20parent%3Droot)%0A%20%20%20%20%20%20%20%20for%20class_name%20in%20classes%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20SoftmaxNode(class_name%2C%20parent%3Dsuperclass_node)%0A%0A%20%20%20%20%23%20Now%20that%20the%20tree%20is%20built%2C%20we%20can%20set%20the%20indexes%0A%20%20%20%20%23%20This%20makes%20the%20tree%20read-only%0A%20%20%20%20root.set_indexes()%0A%20%20%20%20name_to_node_id%20%3D%20%7Bnode.name%3A%20root.node_to_id%5Bnode%5D%20for%20node%20in%20root.leaves%7D%0A%20%20%20%20index_to_node_id%20%3D%20%7B%0A%20%20%20%20%20%20%20%20i%3A%20name_to_node_id%5Bname%5D%20for%20i%2C%20name%20in%20enumerate(train_data.classes)%0A%20%20%20%20%7D%0A%0A%20%20%20%20%23%20Render%20the%20hierarchy%0A%20%20%20%20mo.Html(root.svg())%0A%20%20%20%20return%20(%0A%20%20%20%20%20%20%20%20HierarchicalSoftmaxLazyLinear%2C%0A%20%20%20%20%20%20%20%20HierarchicalSoftmaxLoss%2C%0A%20%20%20%20%20%20%20%20RankAccuracyTorchMetric%2C%0A%20%20%20%20%20%20%20%20SoftmaxNode%2C%0A%20%20%20%20%20%20%20%20index_to_node_id%2C%0A%20%20%20%20%20%20%20%20root%2C%0A%20%20%20%20)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%23%23%23%20Create%20DataLoaders%20with%20hierarchical%20labels%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(DataLoader%2C%20batch_size%2C%20index_to_node_id%2C%20test_data%2C%20torch%2C%20train_data)%3A%0A%20%20%20%20class%20HierarchicalDataset(torch.utils.data.Dataset)%3A%0A%20%20%20%20%20%20%20%20def%20__init__(self%2C%20dataset%2C%20index_to_node_id)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20self.dataset%20%3D%20dataset%0A%20%20%20%20%20%20%20%20%20%20%20%20self.index_to_node_id%20%3D%20index_to_node_id%0A%0A%20%20%20%20%20%20%20%20def%20__getitem__(self%2C%20idx)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20image%2C%20label%20%3D%20self.dataset%5Bidx%5D%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20image%2C%20self.index_to_node_id%5Blabel%5D%0A%0A%20%20%20%20%20%20%20%20def%20__len__(self)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20len(self.dataset)%0A%0A%20%20%20%20hierarchical_train_loader%20%3D%20DataLoader(HierarchicalDataset(train_data%2C%20index_to_node_id)%2C%20batch_size%3Dbatch_size%2C%20shuffle%3DTrue)%0A%20%20%20%20hierarchical_test_loader%20%3D%20DataLoader(HierarchicalDataset(test_data%2C%20index_to_node_id)%2C%20batch_size%3Dbatch_size%2C%20shuffle%3DFalse)%0A%20%20%20%20return%20hierarchical_test_loader%2C%20hierarchical_train_loader%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%23%23%23%20Create%20the%20Hierarchical%20Image%20Classifier%20model%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(%0A%20%20%20%20BasicImageClassifier%2C%0A%20%20%20%20HierarchicalSoftmaxLazyLinear%2C%0A%20%20%20%20HierarchicalSoftmaxLoss%2C%0A%20%20%20%20RankAccuracyTorchMetric%2C%0A%20%20%20%20SoftmaxNode%2C%0A%20%20%20%20model_body%2C%0A%20%20%20%20nn%2C%0A%20%20%20%20root%2C%0A)%3A%0A%20%20%20%20class%20HierarchicalImageClassifier(BasicImageClassifier)%3A%0A%20%20%20%20%20%20%20%20%23%20Just%20overriding%20the%20init%20-%20keep%20the%20rest%20of%20the%20code%0A%20%20%20%20%20%20%20%20def%20__init__(self%2C%20root%3A%20SoftmaxNode)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20super().__init__()%0A%20%20%20%20%20%20%20%20%20%20%20%20self.model%20%3D%20nn.Sequential(%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20model_body()%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20HierarchicalSoftmaxLazyLinear(root%3Droot)%0A%20%20%20%20%20%20%20%20%20%20%20%20)%0A%20%20%20%20%20%20%20%20%20%20%20%20self.loss_fn%20%3D%20HierarchicalSoftmaxLoss(root)%0A%20%20%20%20%20%20%20%20%20%20%20%20self.metrics%20%3D%20%5B%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20RankAccuracyTorchMetric(%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20root%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%7B1%3A%20%22superclass_accuracy%22%2C%202%3A%20%22class_accuracy%22%7D%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%5D%0A%20%20%20%20%20%20%20%20%20%20%20%20self.root%20%3D%20root%0A%0A%20%20%20%20hierarchical_model%20%3D%20HierarchicalImageClassifier(root)%20%20%20%20%20%20%20%20%0A%20%20%20%20hierarchical_model%0A%20%20%20%20return%20(hierarchical_model%2C)%0A%0A%0A%40app.cell%0Adef%20_(%0A%20%20%20%20CSVLogger%2C%0A%20%20%20%20L%2C%0A%20%20%20%20epochs%2C%0A%20%20%20%20hierarchical_model%2C%0A%20%20%20%20hierarchical_test_loader%2C%0A%20%20%20%20hierarchical_train_loader%2C%0A)%3A%0A%20%20%20%20hierarchical_logger%20%3D%20CSVLogger(save_dir%3D%22lightning_logs%22%2C%20name%3D%22hierarchical_model%22)%0A%20%20%20%20hierarchical_trainer%20%3D%20L.Trainer(max_epochs%3Depochs%2C%20accelerator%3D%22auto%22%2C%20enable_checkpointing%3DFalse%2C%20logger%3Dhierarchical_logger)%0A%20%20%20%20hierarchical_trainer.fit(hierarchical_model%2C%20train_dataloaders%3Dhierarchical_train_loader%2C%20val_dataloaders%3Dhierarchical_test_loader)%0A%20%20%20%20return%20(hierarchical_logger%2C)%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%23%23%23%20Plot%20the%20validation%20results%20at%20both%20the%20superclass%20and%20the%20class%20levels%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(Path%2C%20go%2C%20hierarchical_logger%2C%20pd)%3A%0A%20%20%20%20hierarchical_df%20%3D%20pd.read_csv(Path(hierarchical_logger.log_dir)%20%2F%20%22metrics.csv%22)%0A%20%20%20%20hierarchical_df%20%3D%20hierarchical_df.dropna(subset%3D%5B%22val_class_accuracy%22%5D)%0A%20%20%20%20hierarchical_fig%20%3D%20go.Figure()%0A%20%20%20%20hierarchical_fig.add_trace(go.Scatter(x%3Dhierarchical_df%5B%22epoch%22%5D%2C%20y%3Dhierarchical_df%5B%22val_superclass_accuracy%22%5D%2C%20mode%3D'lines'%2C%20name%3D'superclass'))%0A%20%20%20%20hierarchical_fig.add_trace(go.Scatter(x%3Dhierarchical_df%5B%22epoch%22%5D%2C%20y%3Dhierarchical_df%5B%22val_class_accuracy%22%5D%2C%20mode%3D'lines'%2C%20name%3D'class'))%0A%20%20%20%20hierarchical_fig.update_layout(%0A%20%20%20%20%20%20%20%20xaxis_title%3D%22Epochs%22%2C%0A%20%20%20%20%20%20%20%20yaxis_title%3D%22Accuracy%22%2C%0A%20%20%20%20)%0A%20%20%20%20hierarchical_fig%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20return%0A%0A%0Aif%20__name__%20%3D%3D%20%22__main__%22%3A%0A%20%20%20%20app.run()%0A
e3d516eaee59ebf9049f23e8f11497c3cbc1c986c7ffc43cc095943849bf84c7