001// License: GPL. For details, see LICENSE file. 002package org.openstreetmap.josm.tools; 003 004import java.awt.Dimension; 005import java.awt.geom.Point2D; 006import java.awt.geom.Rectangle2D; 007import java.awt.image.BufferedImage; 008import java.util.HashMap; 009import java.util.HashSet; 010import java.util.Map; 011import java.util.Objects; 012import java.util.Set; 013 014/** 015 * Image warping algorithm. 016 * 017 * Deforms an image geometrically according to a given transformation formula. 018 * @since 11858 019 */ 020public final class ImageWarp { 021 022 private ImageWarp() { 023 // Hide default constructor 024 } 025 026 /** 027 * Transformation that translates the pixel coordinates. 028 */ 029 public interface PointTransform { 030 /** 031 * Translates pixel coordinates. 032 * @param pt pixel coordinates 033 * @return transformed pixel coordinates 034 */ 035 Point2D transform(Point2D pt); 036 } 037 038 /** 039 * Wrapper that optimizes a given {@link ImageWarp.PointTransform}. 040 * 041 * It does so by spanning a grid with certain step size. It will invoke the 042 * potentially expensive master transform only at those grid points and use 043 * bilinear interpolation to approximate transformed values in between. 044 * <p> 045 * For memory optimization, this class assumes that rows are more or less scanned 046 * one-by-one as is done in {@link ImageWarp#warp}. I.e. this transform is <em>not</em> 047 * random access in the y coordinate. 048 */ 049 public static class GridTransform implements ImageWarp.PointTransform { 050 051 private final double stride; 052 private final ImageWarp.PointTransform trfm; 053 054 private final Map<Integer, Map<Integer, Point2D>> cache; 055 056 private final boolean consistencyTest; 057 private final Set<Integer> deletedRows; 058 059 /** 060 * Create a new GridTransform. 061 * @param trfm the master transform, that needs to be optimized 062 * @param stride step size 063 */ 064 public GridTransform(ImageWarp.PointTransform trfm, double stride) { 065 this.trfm = trfm; 066 this.stride = stride; 067 this.cache = new HashMap<>(); 068 this.consistencyTest = Logging.isDebugEnabled(); 069 if (consistencyTest) { 070 deletedRows = new HashSet<>(); 071 } else { 072 deletedRows = null; 073 } 074 } 075 076 @Override 077 public Point2D transform(Point2D pt) { 078 int xIdx = (int) Math.floor(pt.getX() / stride); 079 int yIdx = (int) Math.floor(pt.getY() / stride); 080 double dx = pt.getX() / stride - xIdx; 081 double dy = pt.getY() / stride - yIdx; 082 Point2D value00 = getValue(xIdx, yIdx); 083 Point2D value01 = getValue(xIdx, yIdx + 1); 084 Point2D value10 = getValue(xIdx + 1, yIdx); 085 Point2D value11 = getValue(xIdx + 1, yIdx + 1); 086 double valueX = (value00.getX() * (1-dx) + value10.getX() * dx) * (1-dy) + 087 (value01.getX() * (1-dx) + value11.getX() * dx) * dy; 088 double valueY = (value00.getY() * (1-dx) + value10.getY() * dx) * (1-dy) + 089 (value01.getY() * (1-dx) + value11.getY() * dx) * dy; 090 return new Point2D.Double(valueX, valueY); 091 } 092 093 private Point2D getValue(int xIdx, int yIdx) { 094 return getRow(yIdx).computeIfAbsent(xIdx, k -> trfm.transform(new Point2D.Double(xIdx * stride, yIdx * stride))); 095 } 096 097 private Map<Integer, Point2D> getRow(int yIdx) { 098 cleanUp(yIdx - 3); 099 Map<Integer, Point2D> row = cache.get(yIdx); 100 if (row == null) { 101 row = new HashMap<>(); 102 cache.put(yIdx, row); 103 if (consistencyTest) { 104 // should not create a row that has been deleted before 105 if (deletedRows.contains(yIdx)) throw new AssertionError(); 106 // only ever cache 3 rows at once 107 if (cache.size() > 3) throw new AssertionError(); 108 } 109 } 110 return row; 111 } 112 113 // remove rows from cache that will no longer be used 114 private void cleanUp(int yIdx) { 115 Map<Integer, Point2D> del = cache.remove(yIdx); 116 if (consistencyTest && del != null) { 117 // should delete each row only once 118 if (deletedRows.contains(yIdx)) throw new AssertionError(); 119 deletedRows.add(yIdx); 120 } 121 } 122 } 123 124 /** 125 * Interpolation method. 126 */ 127 public enum Interpolation { 128 /** 129 * Nearest neighbor. 130 * 131 * Simplest possible method. Faster, but not very good quality. 132 */ 133 NEAREST_NEIGHBOR, 134 135 /** 136 * Bilinear. 137 * 138 * Decent quality. 139 */ 140 BILINEAR; 141 } 142 143 /** 144 * Warp an image. 145 * @param srcImg the original image 146 * @param targetDim dimension of the target image 147 * @param invTransform inverse transformation (translates pixel coordinates 148 * of the target image to pixel coordinates of the original image) 149 * @param interpolation the interpolation method 150 * @return the warped image 151 */ 152 public static BufferedImage warp(BufferedImage srcImg, Dimension targetDim, PointTransform invTransform, Interpolation interpolation) { 153 BufferedImage imgTarget = new BufferedImage(targetDim.width, targetDim.height, BufferedImage.TYPE_INT_ARGB); 154 Rectangle2D srcRect = new Rectangle2D.Double(0, 0, srcImg.getWidth(), srcImg.getHeight()); 155 for (int j = 0; j < imgTarget.getHeight(); j++) { 156 for (int i = 0; i < imgTarget.getWidth(); i++) { 157 Point2D srcCoord = invTransform.transform(new Point2D.Double(i, j)); 158 if (srcRect.contains(srcCoord)) { 159 int rgba; 160 switch (interpolation) { 161 case NEAREST_NEIGHBOR: 162 rgba = getColor((int) Math.round(srcCoord.getX()), (int) Math.round(srcCoord.getY()), srcImg); 163 break; 164 case BILINEAR: 165 int x0 = (int) Math.floor(srcCoord.getX()); 166 double dx = srcCoord.getX() - x0; 167 int y0 = (int) Math.floor(srcCoord.getY()); 168 double dy = srcCoord.getY() - y0; 169 int c00 = getColor(x0, y0, srcImg); 170 int c01 = getColor(x0, y0 + 1, srcImg); 171 int c10 = getColor(x0 + 1, y0, srcImg); 172 int c11 = getColor(x0 + 1, y0 + 1, srcImg); 173 rgba = 0; 174 // loop over color components: blue, green, red, alpha 175 for (int ch = 0; ch <= 3; ch++) { 176 int shift = 8 * ch; 177 int chVal = (int) Math.round( 178 (((c00 >> shift) & 0xff) * (1-dx) + ((c10 >> shift) & 0xff) * dx) * (1-dy) + 179 (((c01 >> shift) & 0xff) * (1-dx) + ((c11 >> shift) & 0xff) * dx) * dy); 180 rgba |= chVal << shift; 181 } 182 break; 183 default: 184 throw new AssertionError(Objects.toString(interpolation)); 185 } 186 imgTarget.setRGB(i, j, rgba); 187 } 188 } 189 } 190 return imgTarget; 191 } 192 193 private static int getColor(int x, int y, BufferedImage img) { 194 // border strategy: continue with the color of the outermost pixel, 195 return img.getRGB( 196 Utils.clamp(x, 0, img.getWidth() - 1), 197 Utils.clamp(y, 0, img.getHeight() - 1)); 198 } 199}