diff --git a/lib/grpc_reflection/service/builder.ex b/lib/grpc_reflection/service/builder.ex index a6e2df6..07ace26 100644 --- a/lib/grpc_reflection/service/builder.ex +++ b/lib/grpc_reflection/service/builder.ex @@ -14,7 +14,7 @@ defmodule GrpcReflection.Service.Builder do new_state = process_service(service) State.merge(state, new_state) end) - |> State.group_symbols_by_namespace() + |> State.shrink_cycles() {:ok, tree} end diff --git a/lib/grpc_reflection/service/cycle.ex b/lib/grpc_reflection/service/cycle.ex new file mode 100644 index 0000000..d0dc17d --- /dev/null +++ b/lib/grpc_reflection/service/cycle.ex @@ -0,0 +1,44 @@ +defmodule GrpcReflection.Service.Cycle do + @moduledoc false + + def get_cycles(%GrpcReflection.Service.State{files: files}) do + files + |> Map.values() + |> Enum.reject(&String.ends_with?(&1.name, "Extension.proto")) + |> Map.new(fn file -> {file.name, file.dependency} end) + |> find_cycles() + end + + defp find_cycles(graph) do + graph + |> Map.keys() + |> Enum.reduce({[], []}, fn node, {visited, cycles} -> + dfs(node, graph, visited, [], cycles) + end) + |> elem(1) + |> Enum.map(&Enum.sort/1) + |> Enum.sort() + |> Enum.uniq() + end + + defp dfs(node, graph, visited, path, cycles) do + cond do + node in path -> + cycle = [node | Enum.take_while(path, &(&1 != node))] + {visited, [cycle | cycles]} + + node in visited -> + {visited, cycles} + + true -> + {visited, cycles} = + graph + |> Map.get(node, []) + |> Enum.reduce({[node | visited], cycles}, fn neighbor, {v, c} -> + dfs(neighbor, graph, v, [node | path], c) + end) + + {visited, cycles} + end + end +end diff --git a/lib/grpc_reflection/service/state.ex b/lib/grpc_reflection/service/state.ex index d87b33e..7c76e13 100644 --- a/lib/grpc_reflection/service/state.ex +++ b/lib/grpc_reflection/service/state.ex @@ -116,9 +116,68 @@ defmodule GrpcReflection.Service.State do end end - def group_symbols_by_namespace(%__MODULE__{} = state) do - # group symbols by namespace and combine - # IO.inspect(state) - state + def shrink_cycles(%__MODULE__{} = state) do + new_state = + state + |> GrpcReflection.Service.Cycle.get_cycles() + |> Enum.reduce(state, fn filenames, acc -> + files = filenames |> Enum.map(&acc.files[&1]) |> Enum.reject(&is_nil/1) + + if length(files) < 2 do + acc + else + update_with_combined(acc, combine_file_descriptors(files), filenames) + end + end) + + if new_state == state, do: state, else: shrink_cycles(new_state) + end + + defp update_with_combined(state, combined_file, combined_filenames) do + new_files = + state.files + |> Map.drop(combined_filenames) + |> Map.new(fn {filename, descriptor} -> + if Enum.any?(descriptor.dependency, &(&1 in combined_filenames)) do + updated_deps = (descriptor.dependency -- combined_filenames) ++ [combined_file.name] + {filename, %{descriptor | dependency: Enum.uniq(updated_deps)}} + else + {filename, descriptor} + end + end) + |> Map.put(combined_file.name, combined_file) + + new_symbols = + Map.new(state.symbols, fn {symbol, filename} -> + if filename in combined_filenames do + {symbol, combined_file.name} + else + {symbol, filename} + end + end) + + %{state | files: new_files, symbols: new_symbols} + end + + defp combine_file_descriptors(file_descriptors) do + combined_names = Enum.map(file_descriptors, & &1.name) + canonical_name = Enum.min(combined_names) + + Enum.reduce( + file_descriptors, + %Google.Protobuf.FileDescriptorProto{name: canonical_name}, + fn descriptor, acc -> + %{ + acc + | syntax: acc.syntax || descriptor.syntax, + package: acc.package || descriptor.package, + message_type: Enum.uniq(acc.message_type ++ descriptor.message_type), + service: Enum.uniq(acc.service ++ descriptor.service), + enum_type: Enum.uniq(acc.enum_type ++ descriptor.enum_type), + dependency: Enum.uniq(acc.dependency ++ (descriptor.dependency -- combined_names)), + extension: Enum.uniq(acc.extension ++ descriptor.extension) + } + end + ) end end diff --git a/test/case/recursive_message_test.exs b/test/case/recursive_message_test.exs index 887579e..cbc0dec 100644 --- a/test/case/recursive_message_test.exs +++ b/test/case/recursive_message_test.exs @@ -3,10 +3,6 @@ defmodule GrpcReflection.Case.RecursiveMessageTest do use GrpcCase, service: RecursiveMessage.Service.Service - # Recursive message structures cause infinite loops in the builder's graph traversal. - # Tracked for future fix; protos and tests are in place to validate when resolved. - @moduletag :skip - versions = ["v1", "v1alpha"] for version <- versions do diff --git a/test/case/well_known_types_test.exs b/test/case/well_known_types_test.exs index 0096602..2f38886 100644 --- a/test/case/well_known_types_test.exs +++ b/test/case/well_known_types_test.exs @@ -57,9 +57,13 @@ defmodule GrpcReflection.Case.WellKnownTypesTest do } = response end - # Well-known types contain circular references that cause an infinite loop in our - # reflection tree builder, which grpcurl exposes as a stack overflow. Out of scope - # for now; the reflection API itself is verified via the symbol/filename tests above. + test "reflection graph is traversable using grpcurl", ctx do + ops = GrpcReflection.TestClient.grpcurl_service(ctx) + + assert {:call, "well_known_types.WellKnownTypesService.ProcessWellKnownTypes"} in ops + assert {:call, "well_known_types.WellKnownTypesService.EmptyMethod"} in ops + assert {:service, "well_known_types.WellKnownTypesService"} in ops + end end end end diff --git a/test/service/builder_test.exs b/test/service/builder_test.exs index cedc4f6..001b2f7 100644 --- a/test/service/builder_test.exs +++ b/test/service/builder_test.exs @@ -133,11 +133,10 @@ defmodule GrpcReflection.Service.BuilderTest do test "handles a recursive message structure" do assert {:ok, tree} = Builder.build_reflection_tree([RecursiveMessage.Service.Service]) - assert tree.files |> Map.keys() |> Enum.sort() == [ - "recursive_message.Reply.proto", - "recursive_message.Request.proto", - "recursive_message.Service.proto" - ] + # Request and Reply form a cycle and are merged into one file + file_names = tree.files |> Map.keys() |> Enum.sort() + assert length(file_names) == 2 + assert "recursive_message.Service.proto" in file_names assert tree.symbols |> Map.keys() |> Enum.sort() == [ "recursive_message.Reply",