From 9142aca48ff09ed32954eceb3456a255d61945b7 Mon Sep 17 00:00:00 2001 From: Thomas Guillemard Date: Fri, 11 Oct 2019 17:22:24 +0200 Subject: [PATCH] Fix hwopus DecodeInterleaved implementation (#786) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix hwopus DecodeInterleaved implementation Also implement new variants of this api. This should fix #763 * Sample rate shouldn't be hardcoded This fix issues while opening Pokémon Let's Go pause menu. * Apply Ac_K's suggestion about EndianSwap * Address gdkchan's comment * Address Ac_k's comment --- .../Utilities/EndianSwap.cs | 18 +- Ryujinx.HLE/HOS/Font/SharedFontManager.cs | 2 +- .../IHardwareOpusDecoder.cs | 243 ++++++++++++++---- .../Services/Audio/Types/OpusPacketHeader.cs | 24 ++ .../HOS/Services/Sockets/Bsd/IClient.cs | 3 +- 5 files changed, 240 insertions(+), 50 deletions(-) rename {Ryujinx.HLE => Ryujinx.Common}/Utilities/EndianSwap.cs (55%) create mode 100644 Ryujinx.HLE/HOS/Services/Audio/Types/OpusPacketHeader.cs diff --git a/Ryujinx.HLE/Utilities/EndianSwap.cs b/Ryujinx.Common/Utilities/EndianSwap.cs similarity index 55% rename from Ryujinx.HLE/Utilities/EndianSwap.cs rename to Ryujinx.Common/Utilities/EndianSwap.cs index df08191a..049570e3 100644 --- a/Ryujinx.HLE/Utilities/EndianSwap.cs +++ b/Ryujinx.Common/Utilities/EndianSwap.cs @@ -1,6 +1,8 @@ -namespace Ryujinx.HLE.Utilities +using System; + +namespace Ryujinx.Common { - static class EndianSwap + public static class EndianSwap { public static ushort Swap16(ushort value) => (ushort)(((value >> 8) & 0xff) | (value << 8)); @@ -13,5 +15,17 @@ ((uintVal << 8) & 0x00ff0000) | ((uintVal << 24) & 0xff000000)); } + + public static uint FromBigEndianToPlatformEndian(uint value) + { + uint result = value; + + if (BitConverter.IsLittleEndian) + { + result = (uint)EndianSwap.Swap32((int)result); + } + + return result; + } } } diff --git a/Ryujinx.HLE/HOS/Font/SharedFontManager.cs b/Ryujinx.HLE/HOS/Font/SharedFontManager.cs index dfb87f3c..8a936dbf 100644 --- a/Ryujinx.HLE/HOS/Font/SharedFontManager.cs +++ b/Ryujinx.HLE/HOS/Font/SharedFontManager.cs @@ -1,9 +1,9 @@ using LibHac.Fs; using LibHac.Fs.NcaUtils; +using Ryujinx.Common; using Ryujinx.HLE.FileSystem; using Ryujinx.HLE.FileSystem.Content; using Ryujinx.HLE.Resource; -using Ryujinx.HLE.Utilities; using System.Collections.Generic; using System.IO; using static Ryujinx.HLE.Utilities.FontUtils; diff --git a/Ryujinx.HLE/HOS/Services/Audio/HardwareOpusDecoderManager/IHardwareOpusDecoder.cs b/Ryujinx.HLE/HOS/Services/Audio/HardwareOpusDecoderManager/IHardwareOpusDecoder.cs index e23398df..079f2ae7 100644 --- a/Ryujinx.HLE/HOS/Services/Audio/HardwareOpusDecoderManager/IHardwareOpusDecoder.cs +++ b/Ryujinx.HLE/HOS/Services/Audio/HardwareOpusDecoderManager/IHardwareOpusDecoder.cs @@ -1,13 +1,18 @@ +using Concentus; +using Concentus.Enums; using Concentus.Structs; +using Ryujinx.HLE.HOS.Services.Audio.Types; +using System; +using System.IO; +using System.Runtime.InteropServices; namespace Ryujinx.HLE.HOS.Services.Audio.HardwareOpusDecoderManager { class IHardwareOpusDecoder : IpcService { - private const int FixedSampleRate = 48000; - - private int _sampleRate; - private int _channelsCount; + private int _sampleRate; + private int _channelsCount; + private bool _reset; private OpusDecoder _decoder; @@ -15,65 +20,211 @@ namespace Ryujinx.HLE.HOS.Services.Audio.HardwareOpusDecoderManager { _sampleRate = sampleRate; _channelsCount = channelsCount; + _reset = false; - _decoder = new OpusDecoder(FixedSampleRate, channelsCount); + _decoder = new OpusDecoder(sampleRate, channelsCount); } - [Command(0)] - // DecodeInterleaved(buffer) -> (u32, u32, buffer) - public ResultCode DecodeInterleaved(ServiceCtx context) + private ResultCode GetPacketNumSamples(out int numSamples, byte[] packet) { - long inPosition = context.Request.SendBuff[0].Position; - long inSize = context.Request.SendBuff[0].Size; + int result = OpusPacketInfo.GetNumSamples(_decoder, packet, 0, packet.Length); - if (inSize < 8) + numSamples = result; + + if (result == OpusError.OPUS_INVALID_PACKET) { return ResultCode.OpusInvalidInput; } - - long outPosition = context.Request.ReceiveBuff[0].Position; - long outSize = context.Request.ReceiveBuff[0].Size; - - byte[] opusData = context.Memory.ReadBytes(inPosition, inSize); - - int processed = ((opusData[0] << 24) | - (opusData[1] << 16) | - (opusData[2] << 8) | - (opusData[3] << 0)) + 8; - - if ((uint)processed > (ulong)inSize) + else if (result == OpusError.OPUS_BAD_ARG) { return ResultCode.OpusInvalidInput; } - short[] pcm = new short[outSize / 2]; - - int frameSize = pcm.Length / (_channelsCount * 2); - - int samples = _decoder.Decode(opusData, 0, opusData.Length, pcm, 0, frameSize); - - foreach (short sample in pcm) - { - context.Memory.WriteInt16(outPosition, sample); - - outPosition += 2; - } - - context.ResponseData.Write(processed); - context.ResponseData.Write(samples); - return ResultCode.Success; } - [Command(4)] - // DecodeInterleavedWithPerf(buffer) -> (u32, u32, u64, buffer) - public ResultCode DecodeInterleavedWithPerf(ServiceCtx context) + private ResultCode DecodeInterleavedInternal(BinaryReader input, out short[] outPcmData, long outputSize, out uint outConsumed, out int outSamples) { - ResultCode result = DecodeInterleaved(context); + outPcmData = null; + outConsumed = 0; + outSamples = 0; - // TODO: Figure out what this value is. - // According to switchbrew, it is now used. - context.ResponseData.Write(0L); + long streamSize = input.BaseStream.Length; + + if (streamSize < Marshal.SizeOf()) + { + return ResultCode.OpusInvalidInput; + } + + OpusPacketHeader header = OpusPacketHeader.FromStream(input); + + uint totalSize = header.length + (uint)Marshal.SizeOf(); + + if (totalSize > streamSize) + { + return ResultCode.OpusInvalidInput; + } + + byte[] opusData = input.ReadBytes((int)header.length); + + ResultCode result = GetPacketNumSamples(out int numSamples, opusData); + + if (result == ResultCode.Success) + { + if ((uint)numSamples * (uint)_channelsCount * sizeof(short) > outputSize) + { + return ResultCode.OpusInvalidInput; + } + + outPcmData = new short[numSamples * _channelsCount]; + + if (_reset) + { + _reset = false; + + _decoder.ResetState(); + } + + try + { + outSamples = _decoder.Decode(opusData, 0, opusData.Length, outPcmData, 0, outPcmData.Length / _channelsCount); + outConsumed = totalSize; + } + catch (OpusException) + { + // TODO: as OpusException doesn't provide us the exact error code, this is kind of inaccurate in some cases... + return ResultCode.OpusInvalidInput; + } + } + + return ResultCode.Success; + } + + [Command(0)] + // DecodeInterleaved(buffer) -> (u32, u32, buffer) + public ResultCode DecodeInterleavedOriginal(ServiceCtx context) + { + ResultCode result; + + long inPosition = context.Request.SendBuff[0].Position; + long inSize = context.Request.SendBuff[0].Size; + long outputPosition = context.Request.ReceiveBuff[0].Position; + long outputSize = context.Request.ReceiveBuff[0].Size; + + using (BinaryReader inputStream = new BinaryReader(new MemoryStream(context.Memory.ReadBytes(inPosition, inSize)))) + { + result = DecodeInterleavedInternal(inputStream, out short[] outPcmData, outputSize, out uint outConsumed, out int outSamples); + + if (result == ResultCode.Success) + { + byte[] pcmDataBytes = new byte[outPcmData.Length * sizeof(short)]; + Buffer.BlockCopy(outPcmData, 0, pcmDataBytes, 0, pcmDataBytes.Length); + context.Memory.WriteBytes(outputPosition, pcmDataBytes); + + context.ResponseData.Write(outConsumed); + context.ResponseData.Write(outSamples); + } + } + + return result; + } + + [Command(4)] // 6.0.0+ + // DecodeInterleavedWithPerfOld(buffer) -> (u32, u32, u64, buffer) + public ResultCode DecodeInterleavedWithPerfOld(ServiceCtx context) + { + ResultCode result; + + long inPosition = context.Request.SendBuff[0].Position; + long inSize = context.Request.SendBuff[0].Size; + long outputPosition = context.Request.ReceiveBuff[0].Position; + long outputSize = context.Request.ReceiveBuff[0].Size; + + using (BinaryReader inputStream = new BinaryReader(new MemoryStream(context.Memory.ReadBytes(inPosition, inSize)))) + { + result = DecodeInterleavedInternal(inputStream, out short[] outPcmData, outputSize, out uint outConsumed, out int outSamples); + + if (result == ResultCode.Success) + { + byte[] pcmDataBytes = new byte[outPcmData.Length * sizeof(short)]; + Buffer.BlockCopy(outPcmData, 0, pcmDataBytes, 0, pcmDataBytes.Length); + context.Memory.WriteBytes(outputPosition, pcmDataBytes); + + context.ResponseData.Write(outConsumed); + context.ResponseData.Write(outSamples); + + // This is the time the DSP took to process the request, TODO: fill this. + context.ResponseData.Write(0); + } + } + + return result; + } + + [Command(6)] // 6.0.0+ + // DecodeInterleavedOld(bool reset, buffer) -> (u32, u32, u64, buffer) + public ResultCode DecodeInterleavedOld(ServiceCtx context) + { + ResultCode result; + + _reset = context.RequestData.ReadBoolean(); + + long inPosition = context.Request.SendBuff[0].Position; + long inSize = context.Request.SendBuff[0].Size; + long outputPosition = context.Request.ReceiveBuff[0].Position; + long outputSize = context.Request.ReceiveBuff[0].Size; + + using (BinaryReader inputStream = new BinaryReader(new MemoryStream(context.Memory.ReadBytes(inPosition, inSize)))) + { + result = DecodeInterleavedInternal(inputStream, out short[] outPcmData, outputSize, out uint outConsumed, out int outSamples); + + if (result == ResultCode.Success) + { + byte[] pcmDataBytes = new byte[outPcmData.Length * sizeof(short)]; + Buffer.BlockCopy(outPcmData, 0, pcmDataBytes, 0, pcmDataBytes.Length); + context.Memory.WriteBytes(outputPosition, pcmDataBytes); + + context.ResponseData.Write(outConsumed); + context.ResponseData.Write(outSamples); + + // This is the time the DSP took to process the request, TODO: fill this. + context.ResponseData.Write(0); + } + } + + return result; + } + + [Command(8)] // 7.0.0+ + // DecodeInterleaved(bool reset, buffer) -> (u32, u32, u64, buffer) + public ResultCode DecodeInterleaved(ServiceCtx context) + { + ResultCode result; + + _reset = context.RequestData.ReadBoolean(); + + long inPosition = context.Request.SendBuff[0].Position; + long inSize = context.Request.SendBuff[0].Size; + long outputPosition = context.Request.ReceiveBuff[0].Position; + long outputSize = context.Request.ReceiveBuff[0].Size; + + using (BinaryReader inputStream = new BinaryReader(new MemoryStream(context.Memory.ReadBytes(inPosition, inSize)))) + { + result = DecodeInterleavedInternal(inputStream, out short[] outPcmData, outputSize, out uint outConsumed, out int outSamples); + + if (result == ResultCode.Success) + { + byte[] pcmDataBytes = new byte[outPcmData.Length * sizeof(short)]; + Buffer.BlockCopy(outPcmData, 0, pcmDataBytes, 0, pcmDataBytes.Length); + context.Memory.WriteBytes(outputPosition, pcmDataBytes); + + context.ResponseData.Write(outConsumed); + context.ResponseData.Write(outSamples); + + // This is the time the DSP took to process the request, TODO: fill this. + context.ResponseData.Write(0); + } + } return result; } diff --git a/Ryujinx.HLE/HOS/Services/Audio/Types/OpusPacketHeader.cs b/Ryujinx.HLE/HOS/Services/Audio/Types/OpusPacketHeader.cs new file mode 100644 index 00000000..bb4b6d16 --- /dev/null +++ b/Ryujinx.HLE/HOS/Services/Audio/Types/OpusPacketHeader.cs @@ -0,0 +1,24 @@ +using Ryujinx.Common; +using System; +using System.IO; +using System.Runtime.InteropServices; + +namespace Ryujinx.HLE.HOS.Services.Audio.Types +{ + [StructLayout(LayoutKind.Sequential)] + struct OpusPacketHeader + { + public uint length; + public uint finalRange; + + public static OpusPacketHeader FromStream(BinaryReader reader) + { + OpusPacketHeader header = reader.ReadStruct(); + + header.length = EndianSwap.FromBigEndianToPlatformEndian(header.length); + header.finalRange = EndianSwap.FromBigEndianToPlatformEndian(header.finalRange); + + return header; + } + } +} diff --git a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/IClient.cs b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/IClient.cs index 3a02e06c..7db8066a 100644 --- a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/IClient.cs +++ b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/IClient.cs @@ -1,4 +1,5 @@ -using Ryujinx.Common.Logging; +using Ryujinx.Common; +using Ryujinx.Common.Logging; using Ryujinx.HLE.Utilities; using System.Collections.Generic; using System.Net;