import { useChartCtx } from '@weenat/client/dist/core/charts'
import { MetricWithUnit } from '@weenat/client/dist/enums/UnitChoices'
import { useConvertedValue } from '@weenat/client/dist/hooks'
import { fromColorPathToColor } from '@weenat/theme'
import { isNil } from 'lodash-es'
import { FC, useMemo } from 'react'

/** Font size given to every tick label's `<text>` element. */
const TICK_LABEL_SIZE = 10

/** One tick to draw on an axis. */
interface Tick {
  /** Content rendered as the tick's label. */
  value: number | string | null
  /** Vertical position of the tick, applied as `translate(0, yOffset)` on the tick group. */
  yOffset: number
}

/** Props of `PureAxis`. */
interface PureAxisProps {
  /** Ticks to draw, in render order. */
  ticks: Tick[]
  /** Side the tick marks and labels extend toward; `PureAxis` defaults it to `'left'`. */
  orientation?: 'left' | 'right'
  /** Horizontal translation of the whole axis; `PureAxis` defaults it to 0. */
  xPos?: number
  /** Extra gap between a tick mark and its label; `PureAxis` defaults it to 4. */
  tickPadding?: number
  /** When provided, labels are coloured with the theme colour `metrics.<metric>.500`. */
  metric?: MetricWithUnit
  // The layout fields below are accepted but not currently read by PureAxis.
  height: number
  marginTop?: number
  marginBottom?: number
}

/**
 * Presentational SVG y-axis.
 *
 * Renders, inside a `<g>` shifted horizontally by `xPos`, one group per tick that contains
 * a short tick mark and a text label. With `orientation === 'left'` the marks and labels
 * extend to the left and labels are end-anchored. Otherwise they extend to the right and
 * labels are start-anchored.
 *
 * `height`, `marginTop` and `marginBottom` are part of the props type but are not read here.
 */
export const PureAxis: FC<PureAxisProps> = ({
  xPos = 0,
  orientation = 'left',
  tickPadding = 4,
  ticks,
  metric
}) => {
  const isLeft = orientation === 'left'
  // Every label shares one colour, so resolve it once instead of once per tick.
  const labelColor = !isNil(metric) ? fromColorPathToColor(`metrics.${metric}.500`) : 'black'

  return (
    <g transform={`translate(${xPos}, 0)`}>
      {ticks.map(({ value, yOffset }, index) => (
        <g
          // `value` may be null and is not guaranteed to be unique once ticks are
          // converted/formatted, so it cannot serve as a React key. The list position can.
          key={index}
          transform={`translate(0,${yOffset})`}
        >
          <line x2={isLeft ? -4 : 4} stroke='currentColor' />
          <text
            fontVariant={'tabular-nums'}
            textAnchor={isLeft ? 'end' : 'start'}
            fontSize={TICK_LABEL_SIZE}
            fill={labelColor}
            x={isLeft ? -4 - tickPadding : 4 + tickPadding}
            y={3}
          >
            {value}
          </text>
        </g>
      ))}
    </g>
  )
}

/** Props of the chart-context-connected `Axis`. */
interface AxisProps {
  /** Metric used to convert tick values and to colour the labels; omit to show raw values. */
  metric?: MetricWithUnit
  /**
   * Applied to each scale tick before it is converted and labelled. It affects only the
   * label, not the tick's position.
   */
  denormalize?: (val: number) => number
  /** Side of the chart the axis is drawn on; `Axis` defaults it to `'left'`. */
  position?: 'left' | 'right'
  /** Gap between tick marks and labels, forwarded to `PureAxis`. */
  tickPadding?: number
}

/**
 * Y-axis wired to the surrounding chart context.
 *
 * Reads `yScale`, `ticksCount` and the layout values from `useChartCtx()`. It turns the
 * scale's ticks into `Tick`s and renders them with `PureAxis`. Each tick's position comes
 * from `yScale(tick)`. Its label is the tick value, passed through `denormalize` (if given)
 * and then through the metric conversion (if `metric` is given).
 */
const Axis = ({ position = 'left', denormalize, metric, tickPadding }: AxisProps) => {
  const { yScale, ticksCount, width, height, marginLeft, marginRight, marginTop, marginBottom } =
    useChartCtx()
  const { convertValue, formatConvertedValue } = useConvertedValue()

  // A left axis sits at the left margin; a right axis at `width - marginRight`.
  const axisXPosition = position === 'left' ? marginLeft : (width ?? 0) - (marginRight ?? 0)

  const ticks = useMemo((): Tick[] => {
    if (isNil(yScale)) {
      return []
    }

    // Turns a (possibly denormalized) tick value into the label shown on the axis.
    const toDisplayValue = (raw: number) => {
      if (isNil(metric)) {
        return raw
      }
      // NOTE(review): LW_V is the only metric routed through formatConvertedValue;
      // the reason is not visible in this file — confirm with the hook's docs.
      if (metric === 'LW_V') {
        return formatConvertedValue({ metric, value: raw, displayUnit: false })
      }
      return convertValue({ metric, value: raw, displayUnit: false })
    }

    return yScale.ticks(ticksCount).map((tickValue) => ({
      value: toDisplayValue(denormalize ? denormalize(tickValue) : tickValue),
      yOffset: yScale(tickValue)
    }))
  }, [yScale, ticksCount, denormalize, metric, formatConvertedValue, convertValue])

  return (
    <PureAxis
      metric={metric}
      orientation={position}
      xPos={axisXPosition}
      ticks={ticks}
      tickPadding={tickPadding}
      height={height ?? 0}
      marginTop={marginTop}
      marginBottom={marginBottom}
    />
  )
}

export default Axis
