In this Python tutorial, I will show you how to run an SDXL-based model like RealVisXL v4.0 without consuming excessive GPU memory.
I will use the Optimum Quanto library to replace the weights of the pipeline's U-Net, its largest component, with 8-bit floating-point (float8) versions.
Storing the weights in float8 instead of float16 lowers GPU memory consumption at the cost of a slight increase in latency.
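To get a feel for the savings, here is a back-of-the-envelope estimate of the U-Net weight storage. The parameter count is an assumption (roughly 2.6 billion for an SDXL U-Net; the exact figure depends on the checkpoint). The estimate ignores activations and the small overhead of the quantization scales.

```python
# Rough estimate of weight storage for an SDXL U-Net.
# ASSUMPTION: about 2.6 billion parameters; not measured from the checkpoint.
UNET_PARAMS = 2.6e9

# Bytes needed to store one parameter in each data type.
BYTES_PER_PARAM = {"float16": 2, "float8": 1}


def weight_gib(num_params: float, dtype: str) -> float:
    """Return the approximate size of the weights in GiB for the given dtype."""
    return num_params * BYTES_PER_PARAM[dtype] / 1024**3


for dtype in BYTES_PER_PARAM:
    print(f"{dtype}: {weight_gib(UNET_PARAMS, dtype):.2f} GiB")
```

Under these assumptions, float8 halves the memory needed for the U-Net weights.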
Install libraries
!pip install optimum-quanto accelerate diffusers transformers
Load the RealVisXL v4.0 model
from optimum.quanto import freeze, qfloat8, quantize
from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler
import torch
# load the RealVisXL_V4.0 model in half precision (float16)
pipe = StableDiffusionXLPipeline.from_pretrained(
    "SG161222/RealVisXL_V4.0",
    torch_dtype=torch.float16,
)
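The DPM++ 2M Karras sampler is recommended for the RealVisXL v4.0 model, so let’s switch to it. In Diffusers, this sampler corresponds to DPMSolverMultistepScheduler with Karras sigmas enabled.
# switch to the DPM++ 2M Karras sampler
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe.scheduler.config,
    use_karras_sigmas=True,  # without this flag, the scheduler is plain DPM++ 2M
)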
The Diffusers library has a handy feature for reducing memory consumption: model CPU offloading. It keeps the model weights in CPU memory and moves each component (text encoders, U-Net, VAE) to the GPU only when that component is needed during image generation.
To use this feature, add the line of code given below.
pipe.enable_model_cpu_offload()
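A toy calculation shows why offloading helps. The component sizes below are illustrative placeholders, not measurements, and the model ignores activations and other runtime buffers.

```python
# Toy model of model-level CPU offloading: only one component sits on the GPU at a time.
# ASSUMPTION: these sizes (in GiB) are illustrative placeholders, not measured values.
component_gib = {
    "text_encoders": 1.5,
    "unet": 5.0,
    "vae": 0.2,
}

# Without offloading, every component stays on the GPU for the whole run.
all_resident = sum(component_gib.values())

# With offloading, the peak is set by the largest single component.
peak_with_offload = max(component_gib.values())

print(f"all components resident: {all_resident:.1f} GiB")
print(f"peak with offloading:    {peak_with_offload:.1f} GiB")
```

With offloading enabled, the U-Net alone determines the peak in this sketch, which is why it is the component worth quantizing.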
Parameter Quantization
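Now I will use the quantize and freeze functions of the Optimum Quanto library to replace the existing U-Net weights with quantized weights.
# mark the U-Net layers for float8 weight quantization (activations are left unquantized)
quantize(pipe.unet, weights=qfloat8)
# replace the float16 weights with their quantized versions
freeze(pipe.unet)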
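To build some intuition for what float8 weight quantization does to the numbers, here is a toy pure-Python sketch. It is not Quanto's implementation. It assumes the e4m3 variant of float8 (3 mantissa bits, largest value 448), uses one absmax scale per row of weights, and ignores subnormals. All function names are hypothetical.

```python
import math

E4M3_MAX = 448.0  # largest finite value in the float8 e4m3 format


def round_to_e4m3(x: float) -> float:
    """Round x to 4 significant bits (1 implicit + 3 mantissa bits), clamped to +/-448.

    Simplification: subnormals and exponent underflow are ignored.
    """
    if x == 0.0:
        return 0.0
    x = max(-E4M3_MAX, min(E4M3_MAX, x))
    mantissa, exponent = math.frexp(x)    # x = mantissa * 2**exponent, 0.5 <= |mantissa| < 1
    mantissa = round(mantissa * 16) / 16  # keep 4 significant bits
    return math.ldexp(mantissa, exponent)


def quantize_row(weights):
    """Scale a row so its largest magnitude maps to 448, then round every value."""
    scale = max(abs(w) for w in weights) / E4M3_MAX or 1.0  # avoid a zero scale
    return [round_to_e4m3(w / scale) for w in weights], scale


def dequantize_row(quantized, scale):
    """Recover approximate weights from the rounded values and the scale."""
    return [q * scale for q in quantized]


row = [0.12, -0.5, 0.031, 0.25]  # hypothetical weights
quantized, scale = quantize_row(row)
restored = dequantize_row(quantized, scale)

for original, approx in zip(row, restored):
    print(f"{original:+.5f} -> {approx:+.5f}")
```

Each restored weight stays within about 6% of the original in this sketch, which is the precision given up in exchange for halving the storage.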
Image Generation
Let’s add the prompts and configure the pipeline for image generation.
prompt = """extremely detailed illustration of a steampunk train
at the station, intricate details, perfect environment"""
neg_prompt = "blurred, low quality, text, watermark, jpeg"
# specify image generation parameters
image = pipe(
    prompt=prompt,
    negative_prompt=neg_prompt,
    num_inference_steps=35,
    guidance_scale=4,
    width=1024,
    height=1024,
).images[0]
# display image
image
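You can reduce GPU memory consumption in the same way for any other SDXL-based model. For a Stable Diffusion v1.5-based model, load the checkpoint with StableDiffusionPipeline instead of StableDiffusionXLPipeline and pick a resolution that the model supports.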