import matplotlib.pyplot as plt
from astropy.visualization import simple_norm
from lacosmic import remove_cosmics
from lacosmic.utils import make_cosmic_rays, make_gaussian_sources

# --- Synthetic inputs -------------------------------------------------
# A 512x512 field of Gaussian sources together with its error array,
# plus a separate image holding 200 simulated cosmic-ray hits. Both
# generators are seeded so the example is reproducible.
shape = (512, 512)
data, error = make_gaussian_sources(shape, seed=0)
cr_img = make_cosmic_rays(shape, n_cosmics=200, seed=0)

# Ground-truth mask: every pixel where the cosmic-ray image is positive.
true_crmask = cr_img > 0

# Contaminated image = sources + cosmic rays.
data2 = data + cr_img

# Display normalization (sqrt stretch, 99.5 percent) computed from the
# uncontaminated image, so that every image panel drawn later shares the
# same scaling.
norm = simple_norm(data, 'sqrt', percent=99.5)

# --- Cosmic-ray removal -----------------------------------------------
# Returns the cleaned image and the mask of detected cosmic-ray pixels.
# NOTE(review): the positional values 1, 5, 5 are presumably the contrast,
# cosmic-ray threshold and neighbor threshold -- confirm against the
# lacosmic API documentation.
clean_img, cr_mask = remove_cosmics(data2, 1, 5, 5, error=error)

# Plotting: a 2x2 grid of panels, flattened so panels can be indexed 0..3.
fig, ax = plt.subplots(2, 2, figsize=(10, 10))
ax = ax.ravel()

# One (image, title, normalization) entry per panel, in display order.
# The three image panels share ``norm``; the mask panel passes None,
# which is imshow's default scaling.
panels = (
    (data, 'Synthetic Data', norm),
    (data2, 'Synthetic Data with Cosmic Rays', norm),
    (clean_img, 'CR-cleaned Data', norm),
    (cr_mask, 'Cosmic Ray mask', None),
)
for axis, (image, title, panel_norm) in zip(ax, panels):
    axis.imshow(image, norm=panel_norm)
    axis.set_title(title)

plt.tight_layout()
plt.show()