5
5
import cv2
6
6
import numpy as np
7
7
import rasterio
8
+ from image_fragment .fragment import Fragment , ImageFragment
8
9
from osgeo import gdal , osr
9
10
from rasterio .features import shapes
10
11
from rasterio .io import BufferedDatasetWriter , DatasetWriter
11
12
from shapely .geometry import shape
12
13
14
+ from gtkit .imutils import get_pixel_resolution , get_affine_transform
13
15
from gtkit .mesh import create_mesh_using_img_param
14
16
15
17
18
+ class StitchNSplit :
19
+ """
20
+ Class for splitting images into smaller fragments.
21
+
22
+ Attributes:
23
+ split_size (tuple): Size of the fragments to split the image into.
24
+ img_size (tuple): Size of the original image.
25
+ image_fragment (ImageFragment): Image fragment object for managing fragments.
26
+ """
27
+
28
+ def __init__ (self , split_size : tuple , img_size : tuple ):
29
+ """
30
+ Initialize the Split class.
31
+
32
+ Parameters:
33
+ split_size (tuple): Size of the fragments to split the image into.
34
+ img_size (tuple): Size of the original image.
35
+
36
+ Raises:
37
+ ValueError: If the split size is greater than the image size.
38
+ """
39
+
40
+ if split_size [0 ] > img_size [0 ] or split_size [1 ] > img_size [1 ]:
41
+ raise ValueError (
42
+ "Size to Split Can't Be Greater than Image, Given {},"
43
+ " Expected <= {}" .format (split_size , (img_size [0 ], img_size [1 ]))
44
+ )
45
+ self .split_size = split_size
46
+ self .img_size = img_size
47
+
48
+ self .image_fragment = ImageFragment .image_fragment_3d (
49
+ fragment_size = self .split_size , org_size = self .img_size
50
+ )
51
+
52
+ def __len__ (self ):
53
+ """
54
+ Get the number of fragments.
55
+
56
+ Returns:
57
+ int: Number of fragments.
58
+ """
59
+
60
+ return len (self .image_fragment .collection )
61
+
62
+ def __getitem__ (self , index ):
63
+ """
64
+ Get a fragment by index.
65
+
66
+ Parameters:
67
+ index (int): Index of the fragment.
68
+
69
+ Returns:
70
+ tuple: Index and the corresponding fragment.
71
+ """
72
+ return index , self .image_fragment .collection [index ]
73
+
74
+ def split (
75
+ self , image : Union [BufferedDatasetWriter , DatasetWriter ], fragment : Fragment
76
+ ):
77
+ """
78
+ Split the image using a windowing approach.
79
+
80
+ Parameters:
81
+ image (rasterio.io.DatasetReader): Input image dataset reader object.
82
+ fragment (Fragment): Fragment object specifying the region of interest.
83
+
84
+ Returns:
85
+ tuple: A tuple containing the extracted image data and additional keyword arguments.
86
+ """
87
+ raise NotImplementedError
88
+
89
+ def stitch (self , image : np .ndarray , stitched_image : np .ndarray , fragment : Fragment ):
90
+ """
91
+ Stitch an image fragment onto a larger stitched image.
92
+
93
+ This method transfers the data from the provided image fragment onto the specified location
94
+ in the larger stitched image.
95
+
96
+ Parameters:
97
+ image (np.ndarray): The image fragment data to be stitched onto the larger image.
98
+ stitched_image (np.ndarray): The larger stitched image onto which the fragment will be stitched.
99
+ fragment (Fragment): The fragment specifying the region in the larger image where the fragment will be placed.
100
+
101
+ Returns:
102
+ np.ndarray: The stitched image with the fragment transferred onto it.
103
+ """
104
+ raise NotImplementedError
105
+
106
+
107
+ class StitchNSplitGeo (StitchNSplit ):
108
+ """
109
+ Subclass of Split specialized for geospatial image splitting.
110
+ """
111
+
112
+ def __init__ (self , split_size : tuple , img_size : tuple ):
113
+ """
114
+ Initialize the SplitGeo class.
115
+
116
+ Parameters:
117
+ split_size (tuple): Size of the fragments to split the image into.
118
+ img_size (tuple): Size of the original image.
119
+ """
120
+ super ().__init__ (split_size , img_size )
121
+
122
+ def split (
123
+ self , image : Union [BufferedDatasetWriter , DatasetWriter ], fragment : Fragment
124
+ ) -> (np .ndarray , dict ):
125
+ """
126
+ Internal method to extract data from a fragment of a geospatial image.
127
+
128
+ Parameters:
129
+ image (rasterio.io.DatasetReader): Input image dataset reader object.
130
+ fragment (Fragment): Fragment object specifying the region of interest.
131
+
132
+ Returns:
133
+ tuple: A tuple containing the extracted image data and additional keyword arguments.
134
+ """
135
+
136
+ split_image = image .read (window = fragment .position )
137
+
138
+ kwargs_split_image = image .meta .copy ()
139
+ kwargs_split_image .update (
140
+ {
141
+ "height" : self .split_size [0 ],
142
+ "width" : self .split_size [1 ],
143
+ "transform" : image .window_transform (fragment .position ),
144
+ }
145
+ )
146
+
147
+ return split_image .swapaxes (0 , 1 ).swapaxes (1 , 2 ), kwargs_split_image
148
+
149
+ def stitch (self , image : np .ndarray , stitched_image : np .ndarray , fragment : Fragment ):
150
+ """
151
+ Stitch an image fragment onto a larger stitched image.
152
+
153
+ This method transfers the data from the provided image fragment onto the specified location
154
+ in the larger stitched image.
155
+
156
+ Parameters:
157
+ image (np.ndarray): The image fragment data to be stitched onto the larger image.
158
+ stitched_image (np.ndarray): The larger stitched image onto which the fragment will be stitched.
159
+ fragment (Fragment): The fragment specifying the region in the larger image where the fragment will be placed.
160
+
161
+ Returns:
162
+ np.ndarray: The stitched image with the fragment transferred onto it.
163
+ """
164
+ return fragment .transfer_fragment (
165
+ transfer_from = image , transfer_to = stitched_image
166
+ )
167
+
168
+
16
169
@dataclass
17
170
class Bitmap :
18
171
"""
@@ -121,8 +274,9 @@ def geowrite(
121
274
transform (affine.Affine): The affine transformation matrix.
122
275
crs (str)
123
276
"""
277
+ assert image .ndim == 3 , f"Input Image must of shape HxWxC"
278
+ bands = image .shape [- 1 ]
124
279
125
- bands = 1 if image .ndim == 2 else image .shape [- 1 ]
126
280
with rasterio .open (
127
281
save_path ,
128
282
"w" ,
@@ -134,7 +288,7 @@ def geowrite(
134
288
transform = transform ,
135
289
crs = crs ,
136
290
) as dst :
137
- dst .write (image , indexes = 1 )
291
+ dst .write (np . rollaxis ( image , axis = 2 ) )
138
292
139
293
140
294
def georead (
0 commit comments