import React, { FC, useRef, useEffect, useState, MouseEventHandler } from 'react'
import './MLTrainingGrid.css'
import { BitmapImage, BoundingBox } from './BitmapImage'

/**
 * Props accepted by the MLTrainingGrid component.
 */
interface MLTrainingGridProps {
  /** URL used as the `src` of the background `<img>`, which is also painted into the canvas. */
  backgroundImageUrl: string
  /** Mask being edited. Its enabled cells are painted on the canvas and it is mutated in place while drawing. */
  mlMask: BitmapImage
  /** Edge length of one grid cell, expressed in the same units as `width` and `height`. */
  gridScale: number
  /** Full width of the editable area; divided by `uiScale` to get the on-screen width. */
  width: number
  /** Full height of the editable area; divided by `uiScale` to get the on-screen height. */
  height: number
  /** Divisor applied to `width`, `height` and `gridScale` to obtain on-screen sizes. */
  uiScale: number
  /** Current zoom level. The canvas magnification is `maxZoomFactor ** (zoom / zoomMax)`. */
  zoom: number
  /** Magnification reached when `zoom` equals `zoomMax`. */
  maxZoomFactor: number
  /** Value of `zoom` at which `maxZoomFactor` is reached. */
  zoomMax: number
  /** When true the pointer sets grid cells; when false it unsets them. */
  drawMode: boolean
  /** Brush size forwarded to the bitmap's `getGrid` / `setGrid` / `unsetGrid` calls. */
  brushSize: number
  /** When false, mask cells and the hover highlight are painted with zero opacity. */
  overlayVisible: boolean
  /** Fill colour used for enabled mask cells. */
  drawColor: string
  /** Called after every canvas render with the result of the bitmap's `isEmpty()`. */
  setBitmapEmpty: (isEmpty: boolean) => void
}

/**
 * Interactive editor that overlays a paintable grid mask on a background image.
 *
 * The background image, the enabled mask cells, the hover highlight, the mask's
 * bounding box and the grid lines are all painted into a single canvas. When
 * zoomed in, the view is centred on the mask's bounding box. Moving the pointer
 * highlights the brush footprint; pressing or dragging with the primary button
 * sets or unsets cells depending on `drawMode`.
 *
 * The mask (`mlMask`) is mutated in place, so a private render flag is toggled
 * to force a repaint after a click.
 */
const MLTrainingGrid: FC<MLTrainingGridProps> = ({
  backgroundImageUrl,
  mlMask,
  gridScale,
  width,
  height,
  uiScale,
  zoom,
  maxZoomFactor,
  zoomMax,
  drawMode,
  brushSize,
  overlayVisible,
  drawColor,
  setBitmapEmpty,
}: MLTrainingGridProps) => {
  const backgroundImageRef = useRef<HTMLImageElement>(null)
  const canvasRef = useRef<HTMLCanvasElement>(null)

  // Toggled purely to re-run the render effect after the mask was mutated in place.
  const [renderFlag, setRenderFlag] = useState<boolean>(false)
  // Flattened (row-major) index of the cell under the pointer, or -1 when none.
  const [hoveredGridIndex, setHoveredGridIndex] = useState<number>(-1)

  // Grid dimensions in cells.
  const gridWidth = Math.ceil(width / gridScale)
  const gridHeight = Math.ceil(height / gridScale)
  // NOTE(review): an alpha of 128 is outside CSS's 0..1 alpha range and is
  // presumably clamped to fully opaque - confirm whether 0.5 was intended.
  const highlightColor = 'rgb(255, 255, 255, 128)'

  // On-screen (canvas) dimensions.
  const uiHeight = height / uiScale
  const uiWidth = width / uiScale

  // Request a repaint. The functional update avoids toggling from a stale
  // `renderFlag` captured by an older closure (e.g. two requests in one tick).
  const shouldRender = (): void => {
    setRenderFlag(flag => !flag)
  }

  // Draw thin lines separating the grid cells.
  // `canvasWidth` / `canvasHeight` are expected in the same (UI) units as
  // `gridScale / uiScale`.
  const drawGridLines = (context: CanvasRenderingContext2D, canvasWidth: number, canvasHeight: number, gridScale: number, uiScale: number): void => {
    const scale = gridScale / uiScale
    // A zero, negative or NaN cell size would make the loops below unbounded.
    if (!(scale > 0)) {
      return
    }
    // ceil so that the line in front of a trailing partial cell is drawn too.
    const wSteps = Math.ceil(canvasWidth / scale)
    const hSteps = Math.ceil(canvasHeight / scale)

    context.save()

    // Set drawing style.
    context.strokeStyle = 'white'
    context.lineWidth = 0.1
    context.globalAlpha = 0.2

    // Start a fresh path; otherwise whatever path is still current on the
    // context would be stroked again with the grid style.
    context.beginPath()

    // Horizontal lines.
    for (let row = 0; row < hSteps; row++) {
      context.moveTo(0, row * scale)
      context.lineTo(canvasWidth, row * scale)
    }

    // Vertical lines.
    for (let column = 0; column < wSteps; column++) {
      context.moveTo(column * scale, 0)
      context.lineTo(column * scale, canvasHeight)
    }

    context.stroke()
    context.restore()
  }

  // Outline the given bounding box (in grid cells) on the canvas.
  const drawBoundingBox = (context: CanvasRenderingContext2D, boundingBox: BoundingBox, gridScale: number, uiScale: number): void => {
    const scale = gridScale / uiScale
    const left = boundingBox.x1 * scale
    const top = boundingBox.y1 * scale
    // One extra cell so the outline lands on the far side of the last cell.
    const right = boundingBox.x2 * scale + scale
    const bottom = boundingBox.y2 * scale + scale

    context.save()
    context.strokeStyle = 'white'
    context.lineWidth = 0.5
    // strokeRect draws a closed rectangle with properly joined corners and
    // leaves the context's current path untouched.
    context.strokeRect(left, top, right - left, bottom - top)
    context.restore()
  }

  // Repaint the canvas whenever the mask, hover position, zoom or layout changes.
  useEffect(() => {
    const canvas = canvasRef.current
    const backgroundImage = backgroundImageRef.current
    if (canvas === null || backgroundImage === null) {
      return
    }
    const context = canvas.getContext('2d')
    if (context === null) {
      console.error('Error getting 2D context from canvas.')
      return
    }

    const cellSize = gridScale / uiScale
    const overlayAlpha = overlayVisible ? 0.40 : 0

    // Paint a single grid cell (row-major index) in the given colour.
    const fillGridPoint = (gridPointIndex: number, color: string): void => {
      context.save()
      context.globalAlpha = overlayAlpha
      context.fillStyle = color
      context.fillRect(
        (gridPointIndex % gridWidth) * cellSize,
        Math.floor(gridPointIndex / gridWidth) * cellSize,
        cellSize,
        cellSize
      )
      context.restore()
    }

    // Clear the whole drawing buffer. The transform of the previous render is
    // still active on the context, so reset it first; otherwise the cleared
    // rectangle is zoomed/translated and may miss parts of the canvas.
    context.setTransform(1, 0, 0, 1, 0, 0)
    context.clearRect(0, 0, canvas.width, canvas.height)

    // Zoom grows exponentially: factor 1 at zoom 0, maxZoomFactor at zoomMax.
    const zoomStep = Math.pow(maxZoomFactor, 1 / zoomMax)
    const zoomFactor = Math.pow(zoomStep, zoom)
    let xTranslation = 0
    let yTranslation = 0

    const bb = mlMask.boundingBox(0)

    // When zoomed in, centre the view on the mask's bounding box.
    if (zoomFactor > 1) {
      const bbWidth = (bb.x2 - bb.x1) * gridScale
      const bbHeight = (bb.y2 - bb.y1) * gridScale
      const bbCenterX = bb.x1 * gridScale + bbWidth / 2
      const bbCenterY = bb.y1 * gridScale + bbHeight / 2
      const viewLeft = bbCenterX - (width / zoomFactor) / 2
      const viewTop = bbCenterY - (height / zoomFactor) / 2
      xTranslation = -viewLeft / uiScale * zoomFactor
      yTranslation = -viewTop / uiScale * zoomFactor
    }
    // This transform stays on the context after the render; getMousePosition
    // reads it back to map pointer coordinates into drawing coordinates.
    context.setTransform(zoomFactor, 0, 0, zoomFactor, xTranslation, yTranslation)

    // Draw the image into the background. Skip images that have not loaded or
    // failed to load: drawImage throws InvalidStateError for a broken image.
    if (backgroundImage.complete && backgroundImage.naturalWidth > 0) {
      context.drawImage(backgroundImage, 0, 0, uiWidth, uiHeight)
    }

    // Fill in enabled cells.
    mlMask.bitmap.enabledIndices().forEach(enabledIndex => {
      fillGridPoint(enabledIndex, drawColor)
    })

    // Highlight the brush footprint around the hovered cell.
    if (hoveredGridIndex !== -1) {
      mlMask._bitmap.getGrid(hoveredGridIndex, brushSize, gridWidth).forEach(bit => {
        fillGridPoint(bit, highlightColor)
      })
    }

    // Draw bounding box.
    drawBoundingBox(context, bb, gridScale, uiScale)
    // Draw grid lines across the on-screen extent of the image.
    drawGridLines(context, uiWidth, uiHeight, gridScale, uiScale)

    setBitmapEmpty(mlMask._bitmap.isEmpty())
  }, [renderFlag, hoveredGridIndex, zoom, overlayVisible, gridScale, width, height, brushSize,
      maxZoomFactor, zoomMax, uiScale, mlMask, setBitmapEmpty, gridWidth, uiHeight, uiWidth, drawColor])

  // Map a mouse event to drawing coordinates (i.e. with the canvas zoom and
  // translation undone). Returns { x: -1, y: -1 } when the canvas is unavailable.
  // NOTE(review): assumes the canvas is displayed at its intrinsic size (no CSS
  // scaling) - confirm against MLTrainingGrid.css.
  const getMousePosition = (event: React.MouseEvent<HTMLCanvasElement>): { x: number, y: number } => {
    const canvas = canvasRef.current
    if (canvas === null) {
      return { x: -1, y: -1 }
    }
    const context = canvas.getContext('2d')
    if (context === null) {
      console.error('Error getting 2D context from canvas.')
      return { x: -1, y: -1 }
    }

    const canvasRect = canvas.getBoundingClientRect()
    // Transform left in place by the last render: uniform scale plus translation.
    const transformMatrix = context.getTransform()
    const inverseZoom = 1 / transformMatrix.m11

    return {
      x: (event.clientX - canvasRect.left - transformMatrix.m41) * inverseZoom,
      y: (event.clientY - canvasRect.top - transformMatrix.m42) * inverseZoom
    }
  }

  // Set or unset the brush footprint around the given cell, depending on drawMode.
  const applyBrush = (gridIndex: number): void => {
    if (drawMode) {
      mlMask._bitmap.setGrid(gridIndex, brushSize, gridWidth)
    } else {
      mlMask._bitmap.unsetGrid(gridIndex, brushSize, gridWidth)
    }
  }

  // Mouse move handler - tracks the hovered cell so it can be highlighted, and
  // paints while the primary button is held down.
  const mouseMoveHandler: MouseEventHandler<HTMLCanvasElement> = (event: React.MouseEvent<HTMLCanvasElement>) => {
    const scale = gridScale / uiScale
    const mousePosition = getMousePosition(event)
    const gridX = Math.floor(mousePosition.x / scale)
    const gridY = Math.floor(mousePosition.y / scale)

    // Validate column and row separately: a flattened-index check alone lets an
    // out-of-range column wrap around onto a neighbouring row. Written as a
    // positive test so that NaN coordinates are rejected as well.
    const insideGrid = gridX >= 0 && gridX < gridWidth && gridY >= 0 && gridY < gridHeight
    if (!insideGrid) {
      return
    }

    const gridIndex = gridY * gridWidth + gridX
    if (gridIndex === hoveredGridIndex) {
      return
    }

    // Moved onto a new cell. If (only) the primary button is pressed, paint it.
    if (event.buttons === 1) {
      applyBrush(gridIndex)
    }
    setHoveredGridIndex(gridIndex) // Triggers a repaint via the render effect.
  }

  // Mouse down handler - paints at the currently hovered cell.
  const mouseDownHandler: MouseEventHandler<HTMLCanvasElement> = (event: React.MouseEvent<HTMLCanvasElement>): void => {
    // Only the primary button paints, matching the drag behaviour in mouseMoveHandler.
    if (event.button !== 0 || hoveredGridIndex === -1) {
      return
    }
    applyBrush(hoveredGridIndex)
    shouldRender()
  }

  return (
    <div
      className='grid-container'
      style={{
        width: uiWidth,
        height: uiHeight
      }}
    >
      <img
        className='background-image'
        width={uiWidth}
        height={uiHeight}
        src={backgroundImageUrl}
        ref={backgroundImageRef}
        onLoad={() => shouldRender()}
        alt="background"
      />
      <canvas
        ref={canvasRef}
        className='grid-canvas'
        width={uiWidth}
        height={uiHeight}
        onMouseMove={mouseMoveHandler}
        onMouseDown={mouseDownHandler}
      />
    </div>
  )
}

export { MLTrainingGrid }
