2025, Dec 19 09:01

Как собрать изображение из патчей в JAX с reshape и transpose

Разбираем сборку изображений из патчей в JAX без циклов Python: только transpose и reshape. Быстрее, меньше памяти, безопаснее на GPU с @jax.jit и стабильнее.

Сборка изображений из массивов патчей — частая задача в конвейерах JAX. Наивный подход — вложенные циклы Python — работает, но теряет в производительности и на ускорителях может приводить к проблемам с памятью. Хорошая новость: когда патчи расположены на регулярной сетке, исходные тензоры можно восстановить, используя лишь transpose и reshape.

Постановка задачи

Мы начинаем с батча изображений и разбиваем их на патчи. Приведённая ниже реконструкция корректна, но неэффективна из‑за циклов на чистом Python по сетке и пространственным осям.

patch_kernel = jnp.ones((PATCH_VEC, N_CH, PH, PW), dtype=jnp.float32)

def extract_tiles(x):
    tile_grid = lax.conv_general_dilated_patches(
        x, (PH, PW), (PH, PW), padding='VALID'
    )
    # возвращаем каналы последними внутри измерения патча
    return jnp.transpose(tile_grid, [0, 2, 3, 1])

# bfrc — батч изображений формы (batch, channels, height, width)
tile_buf = extract_tiles(bfrc)

# V_SPLITS == IMG_HEIGHT // PH
# H_SPLITS == IMG_WIDTH // PW

# Обратите внимание: обе переменные ссылаются на одно и то же значение; ниже используется второе имя
tiles_vh_c_ph_pw = tiles_alias = jnp.reshape(
    tile_buf, (V_SPLITS, H_SPLITS, N_CH, PH, PW)
)

recon_img = np.zeros(EXP_SHAPE)

for vi in range(0, tiles_vh_c_ph_pw.shape[0]):
    for hi in range(0, tiles_vh_c_ph_pw.shape[1]):
        for ch_ix in range(0, tiles_vh_c_ph_pw.shape[2]):
            for pr in range(0, tiles_vh_c_ph_pw.shape[3]):
                for pc in range(0, tiles_vh_c_ph_pw.shape[4]):
                    r_idx = vi * PH + pr
                    c_idx = hi * PW + pc
                    recon_img[0, ch_ix, r_idx, c_idx] = tiles_vh_c_ph_pw[vi, hi, ch_ix, pr, pc]

# Это утверждение (assert) выполняется
assert jnp.max(jnp.abs(recon_img - bfrc[0])) == 0

Что здесь на самом деле происходит

Тензор патчей представляет собой регулярную сетку формы (V_SPLITS, H_SPLITS, N_CH, PH, PW). Поскольку патчи заполняют изображение без перекрытий, обратная сборка — это чисто операция перестановки. Никакой арифметики не требуется — лишь переупорядочивание осей и схлопывание/расширение измерений. Строка, где одному и тому же преобразованному тензору присваиваются два разных имени, выглядит как оговорка в именовании, но на поведение это не влияет.

Если входной батч имеет форму (batch, channels, height, width), а V_SPLITS = IMG_HEIGHT // PH и H_SPLITS = IMG_WIDTH // PW, восстановление сводится к последовательности шагов reshape и transpose.

Эффективное решение на JAX

Ниже — сборка, работающая только с размещением данных: без циклов Python, только операции над массивами JAX.

# tiles_vh_c_ph_pw имеет форму (V_SPLITS, H_SPLITS, N_CH, PH, PW)
v_bins, h_bins, n_chan, p_h, p_w = tiles_vh_c_ph_pw.shape

full_h = v_bins * p_h
full_w = h_bins * p_w

# Переносим каналы на последнюю ось внутри каждого блока патча: (V, H, PH, PW, C)
ordered = jnp.transpose(tiles_vh_c_ph_pw, (0, 1, 3, 4, 2))

# Переупорядочиваем размеры сетки в непрерывную раскладку изображения
restored = ordered.reshape(v_bins, h_bins, p_h, p_w, n_chan)
restored = restored.transpose(0, 2, 1, 3, 4)
restored = restored.reshape(full_h, full_w, n_chan)

# Итоговая форма: (1, C, H, W)
reconstructed_batch = jnp.transpose(restored, (2, 0, 1))[jnp.newaxis, ...]

Такой подход дешевле, чем вложенные циклы, потому что полностью остаётся внутри скомпилированных операций массивов JAX. Дополнительно можно обернуть его в @jax.jit для ещё большей скорости.

Почему это важно

Исключив циклы на стороне Python, мы позволяем XLA сливать операции над формами и эффективно исполнять их на устройстве. Это снижает накладные расходы и улучшает работу с памятью, а на практике, как отмечено, обёртка реконструкции в @jax.jit может быть необходима, чтобы избежать OOM на GPU.

Выводы

Когда вы режете изображения на регулярную сетку патчей и затем собираете их обратно, мыслите в терминах алгебры форм: сначала используйте transpose, чтобы собрать вместе нужные оси, затем применяйте reshape для схлопывания или расширения до нужной раскладки. Точно проверяйте формы тензоров — особенно положение каналов и размеры сетки — и отдавайте предпочтение скомпилированным преобразованиям JAX вместо циклов Python. Явно фиксированные формы и константы делают пример воспроизводимым и упрощают рассуждение.