import {
  createAsyncIterable,
  TransformStream,
  createPromise,
  isToolCallAssistantMessage,
  functionToolToModelTool,
} from '../utils'
import {
  BaseChatModelInput,
  BaseDoStreamOutputChunk,
  DoStreamOutput,
  DoGenerateOutput,
  SimpleChatModel,
  ToolCall,
  ChatModelMessage,
  AsyncIterableReadableStream,
  Usage,
  ToolCallAssistantMessage,
  ModelTool,
  FunctionTool,
} from '../type'

/**
 * Input accepted by `ReactModel.generateText` / `ReactModel.streamText`: the agent-loop
 * options (`ReactProps`) plus the chat-model request. `tools` may additionally contain
 * function tools, and `topP` / `toolChoice` are camelCase aliases of the request's
 * `top_p` / `tool_choice` (see `processInput`).
 */
type ReactModelInput = ReactProps &
Omit<BaseChatModelInput, 'tools'> & {
  // Model tools, or function tools (objects with an `fn` property) which `processInput`
  // converts with `functionToolToModelTool`.
  tools?: Array<ModelTool | FunctionTool>
  // Alias of `top_p`; takes precedence over it when set.
  topP?: number
  // Alias of `tool_choice`; takes precedence over it when set.
  toolChoice?: 'none' | 'auto' | 'custom'
}

/**
 * Summary of one agent-loop step (one model request), passed to `onStepFinish`.
 */
interface IOnStepFinish {
  // Copy of the conversation. In tool-call steps the copy is taken before that step's
  // assistant message and tool result are appended; in the final step it includes the
  // last assistant message.
  messages: Array<ChatModelMessage>
  // Text content of this step's assistant message.
  text?: string
  // Tool call requested by the model in this step, if any.
  toolCall?: ToolCall
  // Value returned by the tool for `toolCall`; null or omitted for the final step.
  toolResult?: unknown
  // Finish reason of this step (e.g. 'tool_calls' for tool-call steps).
  finishReason?: string
  // Token usage of this step alone.
  stepUsage?: Usage
  // Token usage summed over all steps so far (a copy, safe to keep).
  totalUsage?: Usage
}

/**
 * Options that control the agent loop; `processInput` separates them from the
 * request that is sent to the model.
 */
interface ReactProps {
  // Maximum number of model requests per call. `ReactModel` defaults it to 10 and
  // `processInput` rejects values below 1.
  maxSteps?: number
  // Invoked after each step with a summary of that step.
  onStepFinish?: (prop: IOnStepFinish) => unknown
  abortSignal?: AbortSignal // TODO: implement abortSignal (accepted but not used yet)
}

/**
 * Splits a `ReactModelInput` into the agent-loop options and the request that is
 * forwarded to the underlying chat model.
 *
 * - Function tools (objects with an `fn` property) are converted to model tools.
 * - The camelCase aliases `topP` / `toolChoice` override `top_p` / `tool_choice`.
 * - `messages` is shallow-copied: the agent loop appends assistant and tool messages
 *   to the request's `messages` in place, and without the copy those appends would
 *   leak into the caller's array.
 *
 * @throws Error when `maxSteps` is given and is less than 1.
 */
function processInput(obj: ReactModelInput): [ReactProps, BaseChatModelInput] {
  const { onStepFinish, abortSignal, maxSteps, topP, toolChoice, ...b } = obj

  if (maxSteps != null && maxSteps < 1) {
    throw new Error('`maxSteps` must be greater than 0.')
  }

  return [
    { onStepFinish, abortSignal, maxSteps },
    {
      ...b,
      // Private copy so in-place appends by the agent loop never touch the caller's array.
      messages: Array.isArray(b.messages) ? [...b.messages] : b.messages,
      tools: b.tools?.map((tool) => ('fn' in tool ? functionToolToModelTool(tool) : tool)),
      top_p: topP ?? b.top_p,
      tool_choice: toolChoice ?? b.tool_choice,
    },
  ]
}

/**
 * ReAct-style agent loop on top of a `SimpleChatModel`.
 *
 * The model is queried; while it answers with a tool call (and the `maxSteps` budget
 * allows another request) the tool is run via `callTool`, the assistant message and
 * the tool result are appended to the conversation, and the model is queried again.
 */
export class ReactModel {
  constructor(private model: SimpleChatModel) {}

  /**
   * Runs the agent loop with non-streaming requests.
   *
   * At most `maxSteps` (default 10) model requests are made and `onStepFinish` is
   * awaited after every step. Inside the tool-call loop, an error thrown by the tool,
   * by `onStepFinish` or by a follow-up request is returned in `error` (with empty
   * `text`) instead of being thrown; an error from the first request propagates.
   *
   * @returns the final text, the conversation including the last assistant message,
   *   the summed usage, and every non-empty `rawResponse` in request order.
   */
  public async generateText(_input: ReactModelInput): Promise<{
    text: string
    messages: Array<ChatModelMessage>
    usage: Usage
    rawResponses: Array<unknown>
    error?: any
  }> {
    const rawResponses: Array<unknown> = []
    const totalUsage: Usage = { completion_tokens: 0, prompt_tokens: 0, total_tokens: 0 }

    const [{ onStepFinish, maxSteps = 10 }, input] = processInput(_input)

    // `input.messages` is extended in place below, so reusing the same `input` object
    // makes every request see the whole conversation so far.
    const doGenerate = () => this.model.doGenerate(input)
    let currentRes = await doGenerate()
    // Number of model requests made so far.
    let currentStep = 1
    if (currentRes.rawResponse) rawResponses.push(currentRes.rawResponse)

    let toolCall: ToolCall | null = null

    // TODO: a response could contain several tool calls; this has not been observed so
    // far, so only the first one is handled.
    while (currentStep < maxSteps && (toolCall = getToolCallFromGenerate(currentRes)) != null) {
      const stepUsage = createSolidUsage(currentRes.usage)
      addToUsage(totalUsage, stepUsage)

      // The model asked for a tool call.
      try {
        const toolCallResult = await callTool(toolCall)

        // getToolCallFromGenerate returned non-null, so choices[0] exists.
        const choice = currentRes.choices[0]

        await onStepFinish?.({
          finishReason: choice.finish_reason,
          messages: input.messages.slice(),
          text: choice.message.content,
          toolCall,
          toolResult: toolCallResult,
          stepUsage,
          totalUsage: Object.assign({}, totalUsage),
        })

        // Append the assistant tool-call message and the tool result to the conversation.
        pushNewMessages(input.messages, choice.message as ToolCallAssistantMessage, toolCallResult)

        // Next round with the extended conversation.
        currentRes = await doGenerate()
        if (currentRes.rawResponse) rawResponses.push(currentRes.rawResponse)
        currentStep += 1
      } catch (e) {
        return {
          text: '',
          messages: input.messages,
          usage: totalUsage,
          error: e,
          rawResponses,
        }
      }
    }

    const lastChoice = currentRes?.choices?.[0]
    const lastMessage = lastChoice?.message

    const text = lastMessage?.content ?? ''
    const messages = lastMessage ? [...input.messages, lastMessage] : input.messages

    const stepUsage = createSolidUsage(currentRes?.usage)
    addToUsage(totalUsage, stepUsage)

    await onStepFinish?.({
      // The response may have no choices at all; do not dereference blindly.
      finishReason: lastChoice?.finish_reason,
      messages: messages.slice(),
      text,
      // Non-null only when the step budget ran out while a tool call was pending.
      toolCall: getToolCallFromGenerate(currentRes) ?? undefined,
      toolResult: null,
      stepUsage,
      totalUsage: Object.assign({}, totalUsage),
    })

    return {
      text,
      messages,
      usage: totalUsage,
      rawResponses,
    }
  }

  /**
   * Runs the agent loop with streaming requests.
   *
   * Every response is tee'd and one branch is inspected by `readFunctionCallStream` to
   * detect a tool call. Tool-call steps are executed internally (at most `maxSteps`
   * requests, default 10). The stream of the last request is returned both as
   * `dataStream` (raw chunks) and `textStream` (text deltas only).
   *
   * When the last response is a plain answer, `messages` and `usage` resolve once that
   * stream has been consumed to the end. Inside the tool-call loop, an error thrown by
   * the tool, by `onStepFinish` or by a follow-up request is returned in `error`
   * together with the streams of the current response.
   */
  public async streamText(_input: ReactModelInput): Promise<{
    dataStream: DoStreamOutput
    textStream: AsyncIterableReadableStream<string>
    messages: Promise<Array<ChatModelMessage>>
    usage: Promise<Usage>
    error?: any
  }> {
    const totalUsage: Usage = { completion_tokens: 0, prompt_tokens: 0, total_tokens: 0 }

    const [{ onStepFinish, maxSteps = 10 }, input] = processInput(_input)
    // `input.messages` is extended in place below, so reusing the same `input` object
    // makes every request see the whole conversation so far.
    const doStream = () => this.model.doStream(input)
    let currentRes = await doStream()
    // Number of model requests made so far; must advance or `maxSteps` is never enforced.
    let currentStep = 1
    let readResult: { message: ToolCallAssistantMessage; usage: Usage } | null = null

    // Tee the current response: one branch is consumed here to detect a tool call, the
    // other replaces `currentRes` so it can still be handed to the caller untouched.
    const readCurrentStream = () => {
      const [oldStream, newStream] = currentRes.tee()
      currentRes = createAsyncIterable(oldStream)
      return readFunctionCallStream(newStream)
    }

    // TODO: a response could contain several tool calls; this has not been observed so
    // far, so only the first one is handled.
    // Unlike generateText, besides the tool call the rest of the assistant message has
    // to be rebuilt from the stream as well.
    // Each response is read exactly once: the loop ends on a plain answer
    // (readResult === null) or when the step budget is used up, and in both cases the
    // last read result stays in `readResult` for the code after the loop.
    while ((readResult = await readCurrentStream()) != null && currentStep < maxSteps) {
      const { message: assistantMessage, usage: stepUsage } = readResult
      addToUsage(totalUsage, stepUsage)

      // The model asked for a tool call; readFunctionCallStream always fills tool_calls[0].
      const toolCall = assistantMessage.tool_calls?.[0]
      try {
        const toolCallResult = await callTool(toolCall)

        await onStepFinish?.({
          finishReason: 'tool_calls',
          messages: input.messages.slice(),
          text: assistantMessage.content,
          toolCall,
          toolResult: toolCallResult,
          stepUsage,
          totalUsage: Object.assign({}, totalUsage),
        })

        // Append the assistant tool-call message and the tool result to the conversation.
        pushNewMessages(input.messages, assistantMessage, toolCallResult)
        // Next round with the extended conversation.
        currentRes = await doStream()
        currentStep += 1
      } catch (e) {
        const [s1, s2] = currentRes.tee()
        return {
          messages: Promise.resolve(input.messages),
          dataStream: createAsyncIterable(s1),
          textStream: createAsyncIterable(s2.pipeThrough(new TransformStream({
            transform(chunk, controller) {
              const str = chunk?.choices?.[0]?.delta?.content
              if (typeof str === 'string') controller.enqueue(str)
            },
          }),),),
          usage: Promise.resolve(totalUsage),
          error: e,
        }
      }
    }

    /**
     * Here the last response has been read once and is either:
     * 1. a plain answer without a tool call (regardless of maxSteps), or
     * 2. a tool call that is not executed because maxSteps is used up.
     *
     * Neither case went through the loop body, so we still have to
     * a. append the assistant message, b. account its usage, c. call onStepFinish.
     */

    if (readResult) {
      // Case 2: a tool call is pending but the step budget is exhausted.
      const { message, usage } = readResult
      addToUsage(totalUsage, usage)

      const messages = [...input.messages, message]

      // onStepFinish is optional; awaited for consistency with the loop above.
      await onStepFinish?.({
        messages: messages.slice(),
        finishReason: 'tool_calls',
        stepUsage: usage,
        text: message.content,
        toolCall: message.tool_calls[0],
        totalUsage: Object.assign({}, totalUsage),
      })

      const [s1, s2] = currentRes.tee()
      return {
        messages: Promise.resolve(messages),
        dataStream: createAsyncIterable(s1),
        textStream: createAsyncIterable(s2.pipeThrough(new TransformStream({
          transform(chunk, controller) {
            const str = chunk?.choices?.[0]?.delta?.content
            if (typeof str === 'string') controller.enqueue(str)
          },
        }),),),
        usage: Promise.resolve(totalUsage),
      }
    }

    // Case 1: plain answer. Its content and usage are only known once the caller has
    // read the stream, so they are collected while the chunks pass through.
    // NOTE(review): if the stream errors or is cancelled before completing, `flush`
    // never runs and `messages` / `usage` never settle.
    const messagePromise = createPromise<Array<ChatModelMessage>>()
    const usagePromise = createPromise<Usage>()

    const message: ChatModelMessage = {
      role: 'assistant',
      content: '',
    }
    let finishReason = ''
    const stepUsage: Usage = { completion_tokens: 0, prompt_tokens: 0, total_tokens: 0 }

    const originStream = currentRes.pipeThrough(new TransformStream({
      transform(chunk, controller) {
        // Pass every chunk through unchanged; only tap it to assemble the final result.
        const content = chunk?.choices?.[0]?.delta?.content
        if (typeof content === 'string') {
          message.content += content
        }

        const reason = chunk?.choices?.[0]?.finish_reason
        if (reason) finishReason = reason

        // TODO: stream usage formats differ between models and may need adjusting.
        // hunyuan sends a growing usage on every chunk, so the last one wins;
        // zhipu sends usage only on the last chunk.
        if (chunk?.usage?.completion_tokens) stepUsage.completion_tokens = chunk.usage.completion_tokens
        if (chunk?.usage?.prompt_tokens) stepUsage.prompt_tokens = chunk.usage.prompt_tokens
        if (chunk?.usage?.total_tokens) stepUsage.total_tokens = chunk.usage.total_tokens

        controller.enqueue(chunk)
      },
      flush() {
        messagePromise.res([...input.messages, message])
        addToUsage(totalUsage, stepUsage)
        usagePromise.res(Object.assign({}, totalUsage))
        onStepFinish?.({
          messages: [...input.messages, message],
          finishReason,
          text: message.content,
          stepUsage,
          totalUsage: Object.assign({}, totalUsage),
        })
      },
    }),)

    const [s1, s2] = originStream.tee()

    return {
      messages: messagePromise.promise,
      dataStream: createAsyncIterable(s1),
      textStream: createAsyncIterable(s2.pipeThrough(new TransformStream({
        transform(chunk, controller) {
          const content = chunk?.choices?.[0]?.delta?.content
          if (typeof content === 'string') {
            controller.enqueue(content)
          }
        },
      }),),),
      usage: usagePromise.promise,
    }
  }
}

/**
 * Extracts the requested tool call from a non-streaming model response.
 *
 * Looks only at the first choice. It returns that choice's first `tool_calls` entry
 * when `finish_reason` is `'tool_calls'` and the message is a tool-call assistant
 * message. In every other case it returns `null`.
 */
function getToolCallFromGenerate(output: DoGenerateOutput) {
  const firstChoice = output?.choices?.[0]
  const assistantMessage = firstChoice?.message

  const notAToolCall =
    !firstChoice ||
    firstChoice.finish_reason !== 'tool_calls' ||
    !assistantMessage ||
    !isToolCallAssistantMessage(assistantMessage)

  if (notAToolCall) {
    return null
  }

  return assistantMessage.tool_calls[0]
}

/**
 * Appends one tool-call round to `messages` in place: first the assistant message
 * that requested the tool, then a `tool` message answering its first tool call.
 *
 * The tool result is JSON-encoded. `JSON.stringify` returns `undefined` (not a string)
 * for `undefined`, functions and symbols. Such results are sent as the string `'null'`,
 * so the tool message always has string content.
 */
function pushNewMessages(
  messages: Array<ChatModelMessage>,
  assistantMessage: ToolCallAssistantMessage,
  toolCallResult: unknown,
) {
  const content = JSON.stringify(toolCallResult) ?? 'null'
  messages.push(assistantMessage, {
    role: 'tool',
    tool_call_id: assistantMessage.tool_calls[0].id,
    content,
  })
}

/**
 * Reads a (tee'd) model stream to decide whether the model is making a tool call and,
 * if it is, reassembles the complete assistant message from the deltas.
 *
 * Returns `null` as soon as a chunk's first choice has a `finish_reason` other than
 * `'tool_calls'`, so a plain answer is recognised without reading the rest of the
 * stream. It also returns `null` when no chunk with a choice was seen at all.
 * Otherwise the stream is read to the end, and the assembled message is returned
 * together with the usage the stream reported. Only the first tool call is
 * reassembled.
 *
 * NOTE(review): the early `null` assumes every choice-bearing chunk of a tool-call
 * response carries `finish_reason: 'tool_calls'`; confirm this for each model adapter.
 */
async function readFunctionCallStream(
  stream: ReadableStream<BaseDoStreamOutputChunk>,
): Promise<{ message: ToolCallAssistantMessage; usage: Usage } | null> {
  const stepUsage: Usage = { completion_tokens: 0, prompt_tokens: 0, total_tokens: 0 }
  const aStream = createAsyncIterable(stream)

  const retToolCall: ToolCall = {
    id: '',
    function: {
      name: '',
      arguments: '',
    },
    type: '',
  }

  const retMessage: ToolCallAssistantMessage = {
    role: 'assistant',
    content: '',
    tool_calls: [retToolCall],
  }

  // Set once a chunk confirms a tool call. This prevents returning an empty, bogus
  // tool call for a stream that had no choice-bearing chunks at all.
  let sawToolCallChunk = false

  for await (const chunk of aStream) {
    // Record usage before any `continue`: it may arrive on chunks that carry no
    // choice, no delta, or no tool_calls (e.g. a trailing usage-only chunk).
    // TODO: stream usage formats differ between models and may need adjusting.
    // hunyuan sends a growing usage on every chunk, so the last one wins;
    // zhipu sends usage only on the last chunk.
    if (chunk?.usage?.completion_tokens) stepUsage.completion_tokens = chunk.usage.completion_tokens
    if (chunk?.usage?.prompt_tokens) stepUsage.prompt_tokens = chunk.usage.prompt_tokens
    if (chunk?.usage?.total_tokens) stepUsage.total_tokens = chunk.usage.total_tokens

    const choice = chunk?.choices?.[0]
    // A chunk without a choice carries no tool-call data; it does not decide anything.
    if (!choice) continue

    const { finish_reason, delta } = choice

    if (finish_reason !== 'tool_calls') return null
    sawToolCallChunk = true

    if (!delta) continue

    if (delta.content) retMessage.content += delta.content

    if (!('tool_calls' in delta)) continue
    const toolCall = delta?.tool_calls?.[0]
    // id/type/name arrive whole; the JSON arguments arrive in fragments and are concatenated.
    if (toolCall?.id) retToolCall.id = toolCall.id
    if (toolCall?.type) retToolCall.type = toolCall.type
    if (toolCall?.function?.name) retToolCall.function.name = toolCall.function.name
    if (toolCall?.function?.arguments) retToolCall.function.arguments += toolCall.function.arguments
  }

  if (!sawToolCallChunk) return null

  return {
    message: retMessage,
    usage: stepUsage,
  }
}

/**
 * Registry of the tools the agent loop may call, keyed by function name.
 */
export const toolMap = new Map<string, CallableFunction>()

/**
 * Invokes the tool registered under `toolCall.function.name` with the JSON-decoded
 * `toolCall.function.arguments`. It returns whatever the tool returns, which may be
 * a Promise.
 *
 * An empty argument string (sent by some models for parameterless tools) is passed
 * as `{}`, because `JSON.parse('')` always throws.
 *
 * @throws Error when no tool is registered under that name.
 * @throws SyntaxError when a non-empty argument string is not valid JSON.
 */
function callTool(toolCall: ToolCall) {
  const { name, arguments: rawArgs } = toolCall.function
  const tool = toolMap.get(name)
  if (!tool) {
    throw new Error(`Tool "${name}" is not registered.`)
  }
  const args = rawArgs ? JSON.parse(rawArgs) : {}
  return tool(args)
}

/**
 * Turns an optional, possibly partial usage record into a complete `Usage`, using 0
 * for every counter that is missing (`null` or `undefined`). Always returns a new
 * object.
 */
function createSolidUsage(usage?: Partial<Usage>): Usage {
  const orZero = (count: number | null | undefined): number => count ?? 0
  return {
    completion_tokens: orZero(usage?.completion_tokens),
    prompt_tokens: orZero(usage?.prompt_tokens),
    total_tokens: orZero(usage?.total_tokens),
  }
}

/**
 * Adds each token counter of `sourceUsage` onto `targetUsage`, mutating
 * `targetUsage` in place.
 */
function addToUsage(targetUsage: Usage, sourceUsage: Usage) {
  const counters = ['completion_tokens', 'prompt_tokens', 'total_tokens'] as const
  for (const counter of counters) {
    targetUsage[counter] += sourceUsage[counter]
  }
}
