---
cleanUrl: "pytorch-ignite-usage"
description: "A summary of PyTorch Ignite usage."
---


```sh
# from pip
pip install pytorch-ignite

# from conda
conda install ignite -c pytorch
```

## Quick lookup


- `engine.state.epoch`: Number of epochs the engine has completed. Initialized as 0.
- `engine.state.max_epochs`: Number of epochs to run for. Initialized as 1.
- `engine.state.iteration`: Number of iterations the engine has completed. Initialized as 0.
- `engine.state.output`: The output of the `process_function` defined for the `Engine`.

## Trainer / Evaluator

### Creating a trainer

```python
def step(engine, batch):
    # Write a function that takes a batch and performs one training step.
    ...

trainer = Engine(step)
```

### Example (DCGAN)

```python
# The main function, processing a batch of examples
def step(engine, batch):

    # Unpack the batch. It comes from a dataset, so we have <images, labels> pairs. Discard labels.
    real, _ = batch
    real = real.to(device)

    # -----------------------------------------------------------
    # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
    netD.zero_grad()

    # train with real
    output = netD(real)
    errD_real = bce(output, real_labels)
    D_x = output.mean().item()
    errD_real.backward()

    # get fake image from generator
    noise = get_noise()
    fake = netG(noise)

    # train with fake; detach so gradients don't flow into the generator
    output = netD(fake.detach())
    errD_fake = bce(output, fake_labels)
    D_G_z1 = output.mean().item()
    errD_fake.backward()

    # gradient update
    errD = errD_real + errD_fake
    optimizerD.step()

    # -----------------------------------------------------------
    # (2) Update G network: maximize log(D(G(z)))
    netG.zero_grad()

    # Update generator. We want to make a step that will make it more likely that
    # the discriminator outputs "real".
    output = netD(fake)
    errG = bce(output, real_labels)
    D_G_z2 = output.mean().item()

    # gradient update
    errG.backward()
    optimizerG.step()

    return {
        'errD': errD.item(),
        'errG': errG.item(),
        'D_x': D_x,
        'D_G_z1': D_G_z1,
        'D_G_z2': D_G_z2
    }
```

### Creating a supervised trainer

```python
trainer = create_supervised_trainer(model, optimizer, loss)
```

### Creating a supervised evaluator

```python
metrics = {'accuracy': Accuracy(), 'nll': Loss(loss)}
evaluator = create_supervised_evaluator(model, metrics=metrics)
```