How to use PyTorch Ignite

Installation

```shell
# from pip
pip install pytorch-ignite

# from conda
conda install ignite -c pytorch
```

Quick lookup

engine

engine.state.epoch: Number of epochs the engine has completed. Initialized as 0.
engine.state.max_epochs: Number of epochs to run for. Initialized as 1.
engine.state.iteration: Number of iterations the engine has completed. Initialized as 0.
engine.state.output: The output of the process_function defined for Engine.

Trainer / Evaluator

Creating a Trainer
```python
from ignite.engine import Engine


def step(engine, batch):
    # Write a function that takes a batch and performs one step.
    ...


trainer = Engine(step)
```
Example (DCGAN)
```python
# The main function, processing a batch of examples
def step(engine, batch):
    # unpack the batch. It comes from a dataset, so we have <images, labels> pairs. Discard labels.
    real, _ = batch
    real = real.to(device)

    # -----------------------------------------------------------
    # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
    netD.zero_grad()

    # train with real
    output = netD(real)
    errD_real = bce(output, real_labels)
    D_x = output.mean().item()

    errD_real.backward()

    # get fake image from generator
    noise = get_noise()
    fake = netG(noise)

    # train with fake
    output = netD(fake.detach())
    errD_fake = bce(output, fake_labels)
    D_G_z1 = output.mean().item()

    errD_fake.backward()

    # gradient update
    errD = errD_real + errD_fake
    optimizerD.step()

    # -----------------------------------------------------------
    # (2) Update G network: maximize log(D(G(z)))
    netG.zero_grad()

    # Update generator. We want to make a step that will make it more likely that discriminator outputs "real"
    output = netD(fake)
    errG = bce(output, real_labels)
    D_G_z2 = output.mean().item()

    errG.backward()

    # gradient update
    optimizerG.step()

    return {
        'errD': errD.item(),
        'errG': errG.item(),
        'D_x': D_x,
        'D_G_z1': D_G_z1,
        'D_G_z2': D_G_z2
    }
```
Creating a supervised trainer
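```python
from ignite.engine import create_supervised_trainer

# model, optimizer and loss are assumed to be defined beforehand
trainer = create_supervised_trainer(model, optimizer, loss)
```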
Creating a supervised evaluator
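```python
from ignite.engine import create_supervised_evaluator
from ignite.metrics import Accuracy, Loss

# model and loss are assumed to be defined beforehand
metrics = {'accuracy': Accuracy(), 'nll': Loss(loss)}
evaluator = create_supervised_evaluator(model, metrics=metrics)
```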

Quickstart

```python
import torch
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss

# Net, get_data_loaders, train_batch_size and val_batch_size are assumed to be defined elsewhere
model = Net()
train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.8)
loss = torch.nn.NLLLoss()

trainer = create_supervised_trainer(model, optimizer, loss)
evaluator = create_supervised_evaluator(model,
                                        metrics={
                                            'accuracy': Accuracy(),
                                            'nll': Loss(loss)
                                        })

@trainer.on(Events.ITERATION_COMPLETED)
def log_training_loss(trainer):
    print("Epoch[{}] Loss: {:.2f}".format(trainer.state.epoch, trainer.state.output))

@trainer.on(Events.EPOCH_COMPLETED)
def log_training_results(trainer):
    evaluator.run(train_loader)
    metrics = evaluator.state.metrics
    print("Training Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
          .format(trainer.state.epoch, metrics['accuracy'], metrics['nll']))

@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(trainer):
    evaluator.run(val_loader)
    metrics = evaluator.state.metrics
    print("Validation Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
          .format(trainer.state.epoch, metrics['accuracy'], metrics['nll']))

trainer.run(train_loader, max_epochs=100)
```

Concept

The essence of the framework is the Engine class, an abstraction that loops a given number of times over the provided data, executes a processing function on each batch, and returns a result:
```python
while epoch < max_epochs:
    # run once on data
    for batch in data:
        output = process_function(batch)
```
Thus, a model trainer is simply an engine that loops multiple times over the training dataset and updates model parameters. Similarly, model evaluation can be done with an engine that runs a single time over the validation dataset and computes metrics. For example, a model trainer for a supervised task:
```python
from ignite.engine import Engine

# model, optimizer, loss_fn, prepare_batch and data are assumed to be defined elsewhere
def update_model(trainer, batch):
    model.train()
    optimizer.zero_grad()
    x, y = prepare_batch(batch)
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    loss.backward()
    optimizer.step()
    return loss.item()

trainer = Engine(update_model)
trainer.run(data, max_epochs=100)
```
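To make the trainer/evaluator distinction concrete without any dependencies, here is a minimal plain-Python sketch. The `run` helper, the `params` dictionary, and the toy dataset are all hypothetical stand-ins, not Ignite's API: the same loop acts as a trainer when it runs for many epochs and updates a parameter, and as an evaluator when it runs once and only measures error.

```python
from types import SimpleNamespace


def run(process_function, data, max_epochs=1):
    """Toy stand-in for Engine.run (hypothetical): loop over data and track state."""
    state = SimpleNamespace(epoch=0, iteration=0, output=None, max_epochs=max_epochs)
    while state.epoch < state.max_epochs:
        state.epoch += 1
        for batch in data:
            state.iteration += 1
            state.output = process_function(state, batch)
    return state


# Fit y = w * x with per-sample gradient descent; the true weight is 2.0.
params = {"w": 0.0}
data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]
squared_errors = []


def update_model(state, batch):
    # Trainer step: update the parameter and return the loss for this batch.
    x, y = batch
    error = params["w"] * x - y
    params["w"] -= 0.05 * 2 * error * x  # gradient of squared error w.r.t. w
    return error ** 2


def evaluate(state, batch):
    # Evaluator step: no parameter update, only record the squared error.
    x, y = batch
    squared_errors.append((params["w"] * x - y) ** 2)
    return squared_errors[-1]


train_state = run(update_model, data, max_epochs=20)  # many passes, updates w
eval_state = run(evaluate, data, max_epochs=1)        # single pass, metrics only
mse = sum(squared_errors) / len(squared_errors)
print(params["w"], mse)
```

The only difference between the two runs is the processing function and the number of epochs, which is the point of the Engine abstraction.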

Event and Handlers

To improve the Engine's flexibility, an event system is introduced that facilitates interaction at each step of the run:
engine is started/completed
epoch is started/completed
batch iteration is started/completed
The complete list of events can be found in Events.
Thus, users can execute custom code as an event handler. Let us consider in more detail what happens when run() is called:
```python
fire_event(Events.STARTED)  # Engine is started.
while epoch < max_epochs:
    fire_event(Events.EPOCH_STARTED)  # Epoch is started.
    # run once on data
    for batch in data:
        fire_event(Events.ITERATION_STARTED)  # Batch iteration is started.

        output = process_function(batch)

        fire_event(Events.ITERATION_COMPLETED)  # Batch iteration is completed.
    fire_event(Events.EPOCH_COMPLETED)  # Epoch is completed.
fire_event(Events.COMPLETED)  # Engine is completed.
```
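The firing order above can be reproduced with a small plain-Python sketch. `MiniEngine` and its local `Events` enum are hypothetical and only mimic the shape of Ignite's `@engine.on(...)` decorator; this is not how Ignite is actually implemented.

```python
from collections import defaultdict
from enum import Enum


class Events(Enum):
    # Local stand-in for ignite.engine.Events (illustrative subset).
    STARTED = "started"
    EPOCH_STARTED = "epoch_started"
    ITERATION_STARTED = "iteration_started"
    ITERATION_COMPLETED = "iteration_completed"
    EPOCH_COMPLETED = "epoch_completed"
    COMPLETED = "completed"


class MiniEngine:
    """Toy engine with an event registry (hypothetical, not Ignite's code)."""

    def __init__(self, process_function):
        self.process_function = process_function
        self.handlers = defaultdict(list)
        self.epoch = 0
        self.output = None

    def on(self, event):
        # Decorator that registers a handler for a single event.
        def decorator(handler):
            self.handlers[event].append(handler)
            return handler
        return decorator

    def fire_event(self, event):
        # Call every handler registered for this event, in registration order.
        for handler in self.handlers[event]:
            handler(self)

    def run(self, data, max_epochs=1):
        self.fire_event(Events.STARTED)
        while self.epoch < max_epochs:
            self.epoch += 1
            self.fire_event(Events.EPOCH_STARTED)
            for batch in data:
                self.fire_event(Events.ITERATION_STARTED)
                self.output = self.process_function(self, batch)
                self.fire_event(Events.ITERATION_COMPLETED)
            self.fire_event(Events.EPOCH_COMPLETED)
        self.fire_event(Events.COMPLETED)


engine = MiniEngine(lambda engine, batch: batch * 2)
log = []


@engine.on(Events.EPOCH_STARTED)
def note_epoch(engine):
    log.append(f"epoch {engine.epoch} started")


@engine.on(Events.ITERATION_COMPLETED)
def note_output(engine):
    log.append(f"output {engine.output}")


@engine.on(Events.COMPLETED)
def note_done(engine):
    log.append("done")


engine.run([1, 2], max_epochs=2)
print(log)
```

Handlers only see the events they were registered for, so the log contains one entry per epoch start, one per completed iteration, and a final entry when the run completes.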