Installation
# from pip
pip install pytorch-ignite
# from conda
conda install ignite -c pytorch
Quick lookup
engine
engine.state.epoch: Number of epochs the engine has completed. Initialized as 0.
engine.state.max_epochs: Number of epochs to run for. Initialized as 1.
engine.state.iteration: Number of iterations the engine has completed. Initialized as 0.
engine.state.output: The value returned by the Engine's process_function on the most recent iteration.
Trainer / Evaluator
Creating a trainer
from ignite.engine import Engine

def step(engine, batch):
    # Write a function that takes a batch and performs one training step.
    ...

trainer = Engine(step)
Example (DCGAN)
# The main function, processing a batch of examples.
# netD, netG, bce, real_labels, fake_labels, get_noise, optimizerD, optimizerG and device
# are assumed to be defined elsewhere in the DCGAN script.
def step(engine, batch):
    # unpack the batch. It comes from a dataset, so we have <images, labels> pairs. Discard labels.
    real, _ = batch
    real = real.to(device)

    # -----------------------------------------------------------
    # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
    netD.zero_grad()

    # train with real
    output = netD(real)
    errD_real = bce(output, real_labels)
    D_x = output.mean().item()
    errD_real.backward()

    # get fake image from generator
    noise = get_noise()
    fake = netG(noise)

    # train with fake (detach so this backward pass does not reach G)
    output = netD(fake.detach())
    errD_fake = bce(output, fake_labels)
    D_G_z1 = output.mean().item()
    errD_fake.backward()

    # gradient update (errD is kept for logging; both backward() calls already accumulated D's gradients)
    errD = errD_real + errD_fake
    optimizerD.step()

    # -----------------------------------------------------------
    # (2) Update G network: maximize log(D(G(z)))
    netG.zero_grad()

    # Update generator. We want to make a step that will make it more likely that discriminator outputs "real"
    output = netD(fake)
    errG = bce(output, real_labels)
    D_G_z2 = output.mean().item()
    errG.backward()

    # gradient update
    optimizerG.step()

    return {
        'errD': errD.item(),
        'errG': errG.item(),
        'D_x': D_x,
        'D_G_z1': D_G_z1,
        'D_G_z2': D_G_z2
    }
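The comments in `step` claim that minimizing `bce` against `real_labels`/`fake_labels` maximizes log(D(x)) + log(1 - D(G(z))). Assuming `bce` is `torch.nn.BCELoss` (its definition is not shown above), which averages the per-element binary cross-entropy, the plain-Python check below verifies that claim for one pair of discriminator outputs. The values 0.9 and 0.2 are made up for illustration.

```python
import math


def bce(p: float, target: float) -> float:
    """Binary cross-entropy of one predicted probability p against a 0/1 target."""
    return -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))


# Hypothetical discriminator outputs: D(x) on a real image and D(G(z)) on a fake one.
d_real, d_fake = 0.9, 0.2

# Discriminator step: real images get target 1, fakes get target 0.
errD = bce(d_real, 1.0) + bce(d_fake, 0.0)
d_objective = math.log(d_real) + math.log(1.0 - d_fake)
# errD is the negated objective, so minimizing errD maximizes log(D(x)) + log(1 - D(G(z))).
print(math.isclose(errD, -d_objective))  # True

# Generator step: fakes are scored against target 1, giving -log(D(G(z))).
errG = bce(d_fake, 1.0)
print(math.isclose(errG, -math.log(d_fake)))  # True
```

This is also why the generator step reuses `real_labels`: scoring fakes as "real" turns the same loss into the generator objective of maximizing log(D(G(z))).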
Creating a supervised trainer
from ignite.engine import create_supervised_trainer

trainer = create_supervised_trainer(model, optimizer, loss)
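Creating a supervised evaluator
from ignite.engine import create_supervised_evaluator
from ignite.metrics import Accuracy, Loss

metrics = {'accuracy': Accuracy(), 'nll': Loss(loss)}
evaluator = create_supervised_evaluator(model, metrics=metrics)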
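Ignite metrics such as `Accuracy` follow a reset/update/compute cycle: their state is reset when an evaluator run starts, updated with each batch's output, and computed once at the end of the run. The class below is a hypothetical plain-Python imitation of that cycle for accuracy, not ignite's implementation; it assumes each batch output is a `(y_pred, y)` pair of per-sample class scores and integer targets.

```python
from typing import Sequence, Tuple


class ToyAccuracy:
    """Plain-Python imitation of the reset/update/compute metric cycle."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        # Clear the running counts before a new evaluation run.
        self._correct = 0
        self._total = 0

    def update(self, output: Tuple[Sequence[Sequence[float]], Sequence[int]]) -> None:
        # output is (y_pred, y): per-sample class scores and integer targets.
        y_pred, y = output
        for scores, target in zip(y_pred, y):
            predicted = max(range(len(scores)), key=lambda i: scores[i])  # argmax
            self._correct += int(predicted == target)
            self._total += 1

    def compute(self) -> float:
        if self._total == 0:
            raise ValueError("accuracy needs at least one sample")
        return self._correct / self._total


acc = ToyAccuracy()
acc.update(([[0.1, 0.9], [0.8, 0.2]], [1, 1]))  # batch 1: one right, one wrong
acc.update(([[0.3, 0.7]], [1]))                  # batch 2: right
print(acc.compute())  # 0.6666666666666666
```

Counting correct samples across batches gives 2/3 here; averaging the two per-batch accuracies (0.5 and 1.0) would give 0.75, which is why the sketch accumulates counts rather than batch scores.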
Quickstart
import torch
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss

# Net, get_data_loaders, train_batch_size and val_batch_size are user-defined.
# Net is assumed to output log-probabilities (log_softmax), as NLLLoss expects.
model = Net()
train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.8)
loss = torch.nn.NLLLoss()

trainer = create_supervised_trainer(model, optimizer, loss)
evaluator = create_supervised_evaluator(model,
                                        metrics={
                                            'accuracy': Accuracy(),
                                            'nll': Loss(loss)
                                        })

@trainer.on(Events.ITERATION_COMPLETED)
def log_training_loss(trainer):
    # trainer.state.output is the batch loss returned by the supervised trainer
    print("Epoch[{}] Loss: {:.2f}".format(trainer.state.epoch, trainer.state.output))

@trainer.on(Events.EPOCH_COMPLETED)
def log_training_results(trainer):
    # Run the evaluator over the training set to get epoch-level metrics
    evaluator.run(train_loader)
    metrics = evaluator.state.metrics
    print("Training Results - Epoch: {} Avg accuracy: {:.2f} Avg loss: {:.2f}"
          .format(trainer.state.epoch, metrics['accuracy'], metrics['nll']))

@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(trainer):
    # Same evaluation, this time on the held-out validation set
    evaluator.run(val_loader)
    metrics = evaluator.state.metrics
    print("Validation Results - Epoch: {} Avg accuracy: {:.2f} Avg loss: {:.2f}"
          .format(trainer.state.epoch, metrics['accuracy'], metrics['nll']))

trainer.run(train_loader, max_epochs=100)
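`@trainer.on(...)` registers the decorated function as a handler for the given event. A minimal sketch of how such a decorator can work, using a hypothetical `ToyEngine` with string event names instead of ignite's `Events`: handlers are stored per event and called in registration order, so in this model the training-results handler runs before the validation-results one. To our understanding ignite also calls handlers of the same event in the order they were attached, but this sketch is not its code.

```python
from collections import defaultdict
from typing import Callable, DefaultDict, List


class ToyEngine:
    """Hypothetical stand-in showing how a decorator like trainer.on(...) can register handlers."""

    def __init__(self) -> None:
        self._handlers: DefaultDict[str, List[Callable]] = defaultdict(list)

    def on(self, event: str) -> Callable[[Callable], Callable]:
        def decorator(handler: Callable) -> Callable:
            self._handlers[event].append(handler)
            return handler  # the function stays usable under its own name
        return decorator

    def fire(self, event: str) -> None:
        # Handlers for one event run in the order they were registered.
        for handler in self._handlers[event]:
            handler(self)


trainer = ToyEngine()
log: List[str] = []


@trainer.on("EPOCH_COMPLETED")
def log_training_results(engine: ToyEngine) -> None:
    log.append("training")


@trainer.on("EPOCH_COMPLETED")
def log_validation_results(engine: ToyEngine) -> None:
    log.append("validation")


trainer.fire("EPOCH_COMPLETED")
print(log)  # ['training', 'validation']
```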
Concept
The essence of the framework is the class Engine, an abstraction that loops a given number of times over provided data, executes a processing function, and returns a result:
epoch = 0
while epoch < max_epochs:
    # run once on data
    for batch in data:
        output = process_function(batch)
    epoch += 1
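A runnable plain-Python version of this loop that also keeps the `engine.state` counters from the Quick lookup could look like the sketch below. `ToyState` and `toy_run` are hypothetical names rather than ignite classes, and the process function receives the state object where ignite would pass the `Engine`.

```python
from dataclasses import dataclass
from typing import Any, Callable, Iterable


@dataclass
class ToyState:
    epoch: int = 0        # epochs completed so far
    max_epochs: int = 1   # epochs to run
    iteration: int = 0    # iterations completed so far, counted across all epochs
    output: Any = None    # value returned by the process function on the latest iteration


def toy_run(process_function: Callable[[ToyState, Any], Any],
            data: Iterable, max_epochs: int = 1) -> ToyState:
    state = ToyState(max_epochs=max_epochs)
    while state.epoch < state.max_epochs:
        # run once on data
        for batch in data:
            state.output = process_function(state, batch)
            state.iteration += 1
        state.epoch += 1
    return state


state = toy_run(lambda s, batch: batch * 2, data=[1, 2, 3], max_epochs=2)
print(state.epoch, state.iteration, state.output)  # 2 6 6
```

`data` must be re-iterable (a list here), because it is looped over once per epoch; a one-shot generator would leave every epoch after the first empty.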
Thus, a model trainer is simply an engine that loops multiple times over the training dataset and updates model parameters. Similarly, model evaluation can be done with an engine that runs a single time over the validation dataset and computes metrics. For example, a model trainer for a supervised task:
from ignite.engine import Engine

# model, optimizer, loss_fn, prepare_batch and data are assumed to be defined elsewhere.
def update_model(trainer, batch):
    model.train()
    optimizer.zero_grad()
    x, y = prepare_batch(batch)
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    loss.backward()
    optimizer.step()
    return loss.item()

trainer = Engine(update_model)
trainer.run(data, max_epochs=100)
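The same update cycle (forward pass, loss, gradient, parameter step) can be traced without PyTorch. In the hypothetical sketch below, the "model" is a single weight `w` in y = w * x, the loss is squared error, the gradient is derived by hand in place of `loss.backward()`, and a plain SGD step stands in for `optimizer.step()`. The data and learning rate are made up.

```python
from typing import Dict, List, Tuple

# Hypothetical toy data generated by y = 2 * x, so training should drive w towards 2.
data: List[Tuple[float, float]] = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]
params: Dict[str, float] = {"w": 0.0}
lr = 0.05


def update_model(trainer, batch: Tuple[float, float]) -> float:
    # trainer is unused here; it mirrors the (engine, batch) signature above.
    x, y = batch
    y_pred = params["w"] * x          # forward pass
    loss = (y_pred - y) ** 2          # squared-error loss
    grad = 2.0 * (y_pred - y) * x     # d(loss)/dw, the role played by loss.backward()
    params["w"] -= lr * grad          # plain SGD step, the role played by optimizer.step()
    return loss


def run(process_function, data, max_epochs: int) -> float:
    # Loop over the data max_epochs times and return the last batch loss.
    loss = float("nan")
    for _ in range(max_epochs):
        for batch in data:
            loss = process_function(None, batch)
    return loss


last_loss = run(update_model, data, max_epochs=100)
print(round(params["w"], 4))  # 2.0
```

Because each step nudges `w` against the gradient of one sample's loss, repeated epochs shrink the error; with this learning rate the distance from 2 is multiplied by about 0.054 per epoch.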
Events and Handlers
To improve the Engine's flexibility, an event system is introduced that facilitates interaction at each step of the run:
• engine is started/completed
• epoch is started/completed
• batch iteration is started/completed
The complete list of events can be found in Events.
Thus, the user can execute custom code as an event handler. Let us consider in more detail what happens when run() is called:
fire_event(Events.STARTED)  # Engine is started.
epoch = 0
while epoch < max_epochs:
    fire_event(Events.EPOCH_STARTED)  # Epoch is started.
    # run once on data
    for batch in data:
        fire_event(Events.ITERATION_STARTED)  # Batch iteration is started.
        output = process_function(batch)
        fire_event(Events.ITERATION_COMPLETED)  # Batch iteration is completed.
    fire_event(Events.EPOCH_COMPLETED)  # Epoch is completed.
    epoch += 1
fire_event(Events.COMPLETED)  # Engine is completed.
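To see the order this produces, the sketch below turns the pseudocode into runnable plain Python. `fire_event` here is a hypothetical stand-in that only records event names as strings instead of calling handlers, and ignite's full list of events (see Events) may contain more than the six shown here.

```python
from typing import Any, Callable, Iterable, List

fired: List[str] = []


def fire_event(name: str) -> None:
    # Toy stand-in: record the event name instead of calling registered handlers.
    fired.append(name)


def run(process_function: Callable[[Any], Any], data: Iterable, max_epochs: int) -> None:
    fire_event("STARTED")
    epoch = 0
    while epoch < max_epochs:
        fire_event("EPOCH_STARTED")
        for batch in data:
            fire_event("ITERATION_STARTED")
            process_function(batch)
            fire_event("ITERATION_COMPLETED")
        fire_event("EPOCH_COMPLETED")
        epoch += 1
    fire_event("COMPLETED")


run(lambda batch: batch, data=["b1", "b2"], max_epochs=2)
print(fired[:4])  # ['STARTED', 'EPOCH_STARTED', 'ITERATION_STARTED', 'ITERATION_COMPLETED']
print(len(fired))  # 14
```

With 2 epochs of 2 batches, STARTED and COMPLETED fire once each, the epoch events twice each, and the iteration events four times each: 1 + 2 * (1 + 2 * 2 + 1) + 1 = 14.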